
Commit d9e58fa

[Refactor] split packing utils function to ease reuse in RL (#1443)
1 parent 6583c4d commit d9e58fa

2 files changed: 159 additions & 146 deletions


examples/v1/scripts/run_rl.sh

Lines changed: 12 additions & 4 deletions
@@ -27,8 +27,8 @@ export PYTHONPATH=$(pwd):$PYTHONPATH
 
 # ray environment variables
 export MASTER_PORT=6000
-export WORLD_SIZE=$NODE_COUNT
-export RANK=$NODE_RANK
+export WORLD_SIZE=${NODE_COUNT:-"1"}
+export RANK=${NODE_RANK:-"0"}
 export RAY_MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
 export RAY_RANK=${RANK:-0} # 0 means head node, >0 means worker node
 export RAY_HEAD_PORT=${RAY_HEAD_PORT:-"6379"}
@@ -66,11 +66,18 @@ current_time=$(date "+%m%d%H")
 # use the last component of the model path as model_name and the second-to-last component of the data path as data_name
 model_dir_name=$(basename "$MODEL_PATH")
 data_dir_name=$(basename "$(dirname "$DATA_PATH")")
-DIR=$(pwd)
-export WORK_DIR="${DIR}/work_dirs/${model_dir_name}_${data_dir_name}_${infer_backend_lower}"
+
+if [ "x$WORK_DIR" = "x" ]; then
+    DIR=$(pwd)
+    export WORK_DIR="${DIR}/work_dirs/${model_dir_name}_${data_dir_name}_${infer_backend_lower}"
+else
+    export WORK_DIR=$WORK_DIR
+fi
+echo "WORK_DIR: $WORK_DIR"
 if [ ! -d "$WORK_DIR" ]; then
     mkdir -p "$WORK_DIR"
 fi
+
 export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt"
 if [ "$ACCELERATOR" = "GPU" ]; then
     # TODO: support NPU RL Memory Monitor
@@ -86,6 +93,7 @@ elif [ "$ACCELERATOR" = "NPU" ]; then
     total_cpus=$((node_count * 256))
 fi
 
+WORK_DIR=$(realpath "$WORK_DIR")
 if [ "$RAY_RANK" -eq 0 ]; then
     rm -rf /tmp/ray_log
     export RAY_LOG_DIR="${WORK_DIR}/ray_${current_time}/"

xtuner/v1/datasets/packing.py

Lines changed: 147 additions & 142 deletions
@@ -21,6 +21,41 @@
 logger = get_logger()
 
 
+def get_pack_infos_by_soft_split(inds: list[int], dataset_id: int, num_tokens: np.ndarray, pack_max_length: int):
+    item_buffer: list[int] = []
+    length_buffer: list[int] = []
+    longest = 0
+
+    pack_infos = []
+    for shfl_i in inds:
+        if num_tokens[shfl_i] + sum(length_buffer) <= pack_max_length:
+            item_buffer.append(shfl_i)
+            length_buffer.append(num_tokens[shfl_i])
+            longest = max(longest, num_tokens[shfl_i])
+        else:
+            if len(item_buffer) > 0:
+                info = {
+                    "dataset_id": dataset_id,
+                    "indices": item_buffer,
+                    "longest": int(longest),
+                }
+                pack_infos.append(info)
+
+            item_buffer = [shfl_i]
+            length_buffer = [num_tokens[shfl_i]]
+            longest = num_tokens[shfl_i]
+
+    if len(item_buffer) > 0:
+        info = {
+            "dataset_id": dataset_id,
+            "indices": item_buffer,
+            "longest": int(longest),
+        }
+
+        pack_infos.append(info)
+    return pack_infos
+
+
 class _LegacySoftPackDataset(torch.utils.data.Dataset):
     def __init__(self, datasets, pack_max_length=2048, global_pack=False, seed: int | None = None):
         self.random = random.Random()
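The extracted helper can now be driven directly from RL data code without constructing _LegacySoftPackDataset. A minimal sketch of standalone use, assuming xtuner is importable; the token counts below are invented for illustration:

import numpy as np

from xtuner.v1.datasets.packing import get_pack_infos_by_soft_split

# Made-up per-sample token counts; in practice these come from the tokenized dataset.
num_tokens = np.array([1200, 800, 1500, 300, 2048, 700], dtype=np.int64)
inds = list(range(len(num_tokens)))  # the dataset classes shuffle this before calling

pack_infos = get_pack_infos_by_soft_split(inds, dataset_id=0, num_tokens=num_tokens, pack_max_length=2048)
# Each entry looks like {"dataset_id": 0, "indices": [...], "longest": ...};
# the summed length of every pack stays <= pack_max_length.
for info in pack_infos:
    print(info["indices"], info["longest"])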
@@ -52,37 +87,7 @@ def get_pack_infos(self, dataset, dataset_id, num_tokens):
         inds = list(range(len(dataset)))
         self.random.shuffle(inds)
 
-        item_buffer = []
-        length_buffer = []
-        longest = 0
-
-        pack_infos = []
-        for shfl_i in inds:
-            if num_tokens[shfl_i] + sum(length_buffer) <= self.pack_max_length:
-                item_buffer.append(shfl_i)
-                length_buffer.append(num_tokens[shfl_i])
-                longest = max(longest, num_tokens[shfl_i])
-            else:
-                if len(item_buffer) > 0:
-                    info = {
-                        "dataset_id": dataset_id,
-                        "indices": item_buffer,
-                        "longest": int(longest),
-                    }
-                    pack_infos.append(info)
-
-                item_buffer = [shfl_i]
-                length_buffer = [num_tokens[shfl_i]]
-                longest = num_tokens[shfl_i]
-
-        if len(item_buffer) > 0:
-            info = {
-                "dataset_id": dataset_id,
-                "indices": item_buffer,
-                "longest": int(longest),
-            }
-
-            pack_infos.append(info)
+        pack_infos = get_pack_infos_by_soft_split(inds, dataset_id, num_tokens, self.pack_max_length)
 
         pack_infos = Dataset.from_list(pack_infos)
 
@@ -228,6 +233,62 @@ def get_pack_chunk_infos(
     return pack_infos
 
 
+def get_pack_infos_by_expand_soft_split(
+    inds: list[int],
+    dataset_id: int,
+    num_tokens: np.ndarray,
+    pack_max_length: int,
+    pack_workers: int = 8,
+    pack_chunk_size: int = 10000,
+    flash_attn_block_size: int = 128,
+    pack_len_type: str = "total_block",
+    pack_extra_buffer_size: int = 1000,
+):
+    if pack_workers <= 1:
+        pack_infos = []
+        for i in range(0, len(inds), pack_chunk_size):
+            chunk_inds = inds[i : i + pack_chunk_size]
+            chunk_pack_infos = get_pack_chunk_infos(
+                chunk_inds,
+                dataset_id,
+                pack_max_length,
+                flash_attn_block_size,
+                pack_len_type,
+                pack_extra_buffer_size,
+                num_tokens,
+            )
+            pack_infos.extend(chunk_pack_infos)
+    else:
+        chunks_inds = [inds[i : i + pack_chunk_size] for i in range(0, len(inds), pack_chunk_size)]
+
+        shm = shared_memory.SharedMemory(create=True, size=num_tokens.nbytes)
+        shm_array = np.ndarray(num_tokens.shape, dtype=num_tokens.dtype, buffer=shm.buf)
+        np.copyto(shm_array, num_tokens)
+
+        mp_context = multiprocessing.get_context("fork")
+        process_chunk_with_args = partial(
+            get_pack_chunk_infos,
+            dataset_id=dataset_id,
+            target=pack_max_length,
+            flash_attn_block_size=flash_attn_block_size,
+            pack_len_type=pack_len_type,
+            pack_extra_buffer_size=pack_extra_buffer_size,
+            shm_name=shm.name,
+            shape=num_tokens.shape,
+            dtype=num_tokens.dtype,
+        )
+        with ProcessPoolExecutor(max_workers=pack_workers, mp_context=mp_context) as executor:
+            results = list(tqdm(executor.map(process_chunk_with_args, chunks_inds)))
+
+        pack_infos = []
+        for result in results:
+            pack_infos.extend(result)
+
+        shm.close()
+        shm.unlink()
+    return pack_infos
+
+
 class ExpandSoftPackDataset(_LegacySoftPackDataset):
     def __init__(
         self,
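Like the soft-split helper above, the expand-soft-split path is now callable as a plain function. A rough sketch of the single-worker branch, where pack_workers=1 avoids the fork and shared-memory machinery; xtuner is assumed to be installed and the lengths and limits are illustrative only:

import numpy as np

from xtuner.v1.datasets.packing import get_pack_infos_by_expand_soft_split

# Invented per-sample lengths for illustration.
num_tokens = np.random.randint(64, 4096, size=50_000).astype(np.int64)
inds = list(range(len(num_tokens)))

pack_infos = get_pack_infos_by_expand_soft_split(
    inds,
    dataset_id=0,
    num_tokens=num_tokens,
    pack_max_length=8192,
    pack_workers=1,  # stay in-process; >1 forks workers over pack_chunk_size-index chunks
)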
@@ -259,65 +320,9 @@ def __init__(
            seed=seed,
        )
 
-    @staticmethod
-    def get_pack_infos_staticmethod(
-        inds: list[int],
-        dataset_id: int,
-        num_tokens: np.ndarray,
-        pack_max_length: int,
-        pack_workers: int,
-        pack_chunk_size: int,
-        flash_attn_block_size: int,
-        pack_len_type: str,
-        pack_extra_buffer_size: int,
-    ):
-        if pack_workers <= 1:
-            pack_infos = []
-            for i in range(0, len(inds), pack_chunk_size):
-                chunk_inds = inds[i : i + pack_chunk_size]
-                chunk_pack_infos = get_pack_chunk_infos(
-                    chunk_inds,
-                    dataset_id,
-                    pack_max_length,
-                    flash_attn_block_size,
-                    pack_len_type,
-                    pack_extra_buffer_size,
-                    num_tokens,
-                )
-                pack_infos.extend(chunk_pack_infos)
-        else:
-            chunks_inds = [inds[i : i + pack_chunk_size] for i in range(0, len(inds), pack_chunk_size)]
-
-            shm = shared_memory.SharedMemory(create=True, size=num_tokens.nbytes)
-            shm_array = np.ndarray(num_tokens.shape, dtype=num_tokens.dtype, buffer=shm.buf)
-            np.copyto(shm_array, num_tokens)
-
-            mp_context = multiprocessing.get_context("fork")
-            process_chunk_with_args = partial(
-                get_pack_chunk_infos,
-                dataset_id=dataset_id,
-                target=pack_max_length,
-                flash_attn_block_size=flash_attn_block_size,
-                pack_len_type=pack_len_type,
-                pack_extra_buffer_size=pack_extra_buffer_size,
-                shm_name=shm.name,
-                shape=num_tokens.shape,
-                dtype=num_tokens.dtype,
-            )
-            with ProcessPoolExecutor(max_workers=pack_workers, mp_context=mp_context) as executor:
-                results = list(tqdm(executor.map(process_chunk_with_args, chunks_inds)))
-
-            pack_infos = []
-            for result in results:
-                pack_infos.extend(result)
-
-            shm.close()
-            shm.unlink()
-        return pack_infos
-
     def get_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray):
         inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()
-        pack_infos = self.get_pack_infos_staticmethod(
+        pack_infos = get_pack_infos_by_expand_soft_split(
             inds,
             dataset_id,
             num_tokens,
@@ -408,6 +413,57 @@ def _hard_pack_chunk(
     return out
 
 
+def get_pack_infos_by_hard_split(
+    inds: list[int], dataset_id: int, num_tokens: np.ndarray, pack_max_length: int, pack_workers: int = 1
+):
+    # number of packed samples
+    shfl_inds = inds
+    num_packed_samples = int(num_tokens.sum() / pack_max_length)
+
+    # shuffled cumulative lengths with leading 0
+    shfl_lens: np.ndarray = np.take(num_tokens, shfl_inds)
+    shfl_cu_lens = np.cumsum(shfl_lens, dtype=np.int64)
+    shfl_cu_lens = np.insert(shfl_cu_lens, 0, 0).astype(np.int64, copy=False)
+
+    # shared memory for cu and inds
+    cu_arr = np.asarray(shfl_cu_lens, dtype=np.int64).reshape(-1)
+    inds_arr = np.asarray(shfl_inds, dtype=np.int64).reshape(-1)
+
+    # chunk tasks
+    chunk_size = 10000
+    i_all = list(range(num_packed_samples))
+    chunks = [i_all[i : i + chunk_size] for i in range(0, len(i_all), chunk_size)]
+
+    pack_infos_list = []
+
+    if pack_workers > 1:
+        # Use fork to inherit read-only arrays; no extra shared memory copy needed
+        mp_context = multiprocessing.get_context("fork")
+        fn = partial(
+            _hard_pack_chunk_core,
+            dataset_id=dataset_id,
+            pack_max_length=pack_max_length,
+            cu=cu_arr,
+            inds_arr=inds_arr,
+        )
+        with ProcessPoolExecutor(max_workers=pack_workers, mp_context=mp_context) as ex:
+            for res in tqdm(ex.map(fn, chunks), total=len(chunks)):
+                pack_infos_list.extend(res)
+    else:
+        # single-process path, reuse the same core
+        for i_chunk in tqdm(chunks, total=len(chunks)):
+            pack_infos_list.extend(
+                _hard_pack_chunk_core(
+                    i_chunk,
+                    dataset_id=dataset_id,
+                    pack_max_length=pack_max_length,
+                    cu=cu_arr,
+                    inds_arr=inds_arr,
+                )
+            )
+    return pack_infos_list
+
+
 class HardPackDataset(_LegacySoftPackDataset):
     def __init__(
         self, datasets, pack_max_length=2048, global_pack=False, seed: int | None = None, pack_workers: int = 1
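The hard-split helper derives a fixed pack count, int(num_tokens.sum() / pack_max_length), from the cumulative lengths, and splits the work into 10000-pack chunks that can be spread over forked workers when pack_workers > 1. A hedged usage sketch, again assuming xtuner is importable and using invented sizes:

import numpy as np

from xtuner.v1.datasets.packing import get_pack_infos_by_hard_split

num_tokens = np.full(1_000, 512, dtype=np.int64)  # pretend every sample is 512 tokens
inds = list(range(len(num_tokens)))               # callers shuffle this themselves

# 1_000 * 512 / 2048 -> 250 packed samples
pack_infos = get_pack_infos_by_hard_split(inds, dataset_id=0, num_tokens=num_tokens, pack_max_length=2048)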
@@ -420,63 +476,12 @@ def __init__(
            seed=seed,
        )
 
-    @staticmethod
-    def get_pack_infos_staticmethod(
-        inds: list, dataset_id: int, num_tokens: np.ndarray, pack_max_length: int, pack_workers: int
-    ):
-        # number of packed samples
-        shfl_inds = inds
-        num_packed_samples = int(num_tokens.sum() / pack_max_length)
-
-        # shuffled cumulative lengths with leading 0
-        shfl_lens: np.ndarray = np.take(num_tokens, shfl_inds)
-        shfl_cu_lens = np.cumsum(shfl_lens, dtype=np.int64)
-        shfl_cu_lens = np.insert(shfl_cu_lens, 0, 0).astype(np.int64, copy=False)
-
-        # shared memory for cu and inds
-        cu_arr = np.asarray(shfl_cu_lens, dtype=np.int64).reshape(-1)
-        inds_arr = np.asarray(shfl_inds, dtype=np.int64).reshape(-1)
-
-        # chunk tasks
-        chunk_size = 10000
-        i_all = list(range(num_packed_samples))
-        chunks = [i_all[i : i + chunk_size] for i in range(0, len(i_all), chunk_size)]
-
-        pack_infos_list = []
-
-        if pack_workers > 1:
-            # Use fork to inherit read-only arrays; no extra shared memory copy needed
-            mp_context = multiprocessing.get_context("fork")
-            fn = partial(
-                _hard_pack_chunk_core,
-                dataset_id=dataset_id,
-                pack_max_length=pack_max_length,
-                cu=cu_arr,
-                inds_arr=inds_arr,
-            )
-            with ProcessPoolExecutor(max_workers=pack_workers, mp_context=mp_context) as ex:
-                for res in tqdm(ex.map(fn, chunks), total=len(chunks)):
-                    pack_infos_list.extend(res)
-        else:
-            # single-process path, reuse the same core
-            for i_chunk in tqdm(chunks, total=len(chunks)):
-                pack_infos_list.extend(
-                    _hard_pack_chunk_core(
-                        i_chunk,
-                        dataset_id=dataset_id,
-                        pack_max_length=pack_max_length,
-                        cu=cu_arr,
-                        inds_arr=inds_arr,
-                    )
-                )
-        return pack_infos_list
-
     def get_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray):
         # shuffled indices
         inds = list(range(len(dataset)))
         self.random.shuffle(inds)
 
-        pack_infos_list = self.get_pack_infos_staticmethod(
+        pack_infos_list = get_pack_infos_by_hard_split(
             inds, dataset_id, num_tokens, pack_max_length=self.pack_max_length, pack_workers=self.pack_workers
         )
 
@@ -631,7 +636,7 @@ def get_hard_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray):
         # shuffled indices
         inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()
 
-        pack_infos_list = HardPackDataset.get_pack_infos_staticmethod(
+        pack_infos_list = get_pack_infos_by_hard_split(
             inds, dataset_id, num_tokens, pack_max_length=self.pack_max_length, pack_workers=self.pack_workers
         )
         return pack_infos_list
@@ -640,7 +645,7 @@ def get_soft_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray):
         # shuffled indices
         inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()
 
-        pack_infos_list = ExpandSoftPackDataset.get_pack_infos_staticmethod(
+        pack_infos_list = get_pack_infos_by_expand_soft_split(
             inds,
             dataset_id,
             num_tokens,
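With all three split strategies exposed as free functions, RL pipelines can reuse them on rollout data without touching the SFT dataset classes, which is the stated goal of this refactor. A sketch of what such reuse might look like; the pack_rollouts wrapper below is hypothetical and not part of this commit:

import numpy as np

from xtuner.v1.datasets.packing import (
    get_pack_infos_by_hard_split,
    get_pack_infos_by_soft_split,
)


def pack_rollouts(sample_lengths: list[int], pack_max_length: int, hard: bool = False):
    # Hypothetical RL-side wrapper: group rollout samples into packed batches.
    num_tokens = np.asarray(sample_lengths, dtype=np.int64)
    inds = list(range(len(num_tokens)))  # shuffle here if ordering should be randomized
    if hard:
        return get_pack_infos_by_hard_split(
            inds, dataset_id=0, num_tokens=num_tokens, pack_max_length=pack_max_length
        )
    return get_pack_infos_by_soft_split(inds, dataset_id=0, num_tokens=num_tokens, pack_max_length=pack_max_length)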
