 logger = get_logger()


+def get_pack_infos_by_soft_split(inds: list[int], dataset_id: int, num_tokens: np.ndarray, pack_max_length: int):
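+    """Greedily pack the shuffled sample indices into bins of at most pack_max_length tokens.
+
+    Samples are consumed in the order given by ``inds``; whenever the next sample
+    would overflow the current buffer, the buffer is flushed as one pack. Each pack
+    records the owning ``dataset_id``, its sample ``indices``, and the token count
+    of its ``longest`` sample. A sample longer than ``pack_max_length`` still forms
+    its own (oversized) pack.
+
+    For example, num_tokens=[5, 9, 3] with pack_max_length=10 and inds=[0, 1, 2]
+    yields three single-sample packs, since no two of them fit in one bin.
+    """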
+    item_buffer: list[int] = []
+    length_buffer: list[int] = []
+    longest = 0
+
+    pack_infos = []
+    for shfl_i in inds:
+        if num_tokens[shfl_i] + sum(length_buffer) <= pack_max_length:
+            item_buffer.append(shfl_i)
+            length_buffer.append(num_tokens[shfl_i])
+            longest = max(longest, num_tokens[shfl_i])
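+        # the next sample does not fit: flush the current pack and start a new one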
+        else:
+            if len(item_buffer) > 0:
+                info = {
+                    "dataset_id": dataset_id,
+                    "indices": item_buffer,
+                    "longest": int(longest),
+                }
+                pack_infos.append(info)
+
+            item_buffer = [shfl_i]
+            length_buffer = [num_tokens[shfl_i]]
+            longest = num_tokens[shfl_i]
+
+    if len(item_buffer) > 0:
+        info = {
+            "dataset_id": dataset_id,
+            "indices": item_buffer,
+            "longest": int(longest),
+        }
+
+        pack_infos.append(info)
+    return pack_infos
+
+
 class _LegacySoftPackDataset(torch.utils.data.Dataset):
     def __init__(self, datasets, pack_max_length=2048, global_pack=False, seed: int | None = None):
         self.random = random.Random()
@@ -52,37 +87,7 @@ def get_pack_infos(self, dataset, dataset_id, num_tokens):
         inds = list(range(len(dataset)))
         self.random.shuffle(inds)

-        item_buffer = []
-        length_buffer = []
-        longest = 0
-
-        pack_infos = []
-        for shfl_i in inds:
-            if num_tokens[shfl_i] + sum(length_buffer) <= self.pack_max_length:
-                item_buffer.append(shfl_i)
-                length_buffer.append(num_tokens[shfl_i])
-                longest = max(longest, num_tokens[shfl_i])
-            else:
-                if len(item_buffer) > 0:
-                    info = {
-                        "dataset_id": dataset_id,
-                        "indices": item_buffer,
-                        "longest": int(longest),
-                    }
-                    pack_infos.append(info)
-
-                item_buffer = [shfl_i]
-                length_buffer = [num_tokens[shfl_i]]
-                longest = num_tokens[shfl_i]
-
-        if len(item_buffer) > 0:
-            info = {
-                "dataset_id": dataset_id,
-                "indices": item_buffer,
-                "longest": int(longest),
-            }
-
-            pack_infos.append(info)
+        pack_infos = get_pack_infos_by_soft_split(inds, dataset_id, num_tokens, self.pack_max_length)

         pack_infos = Dataset.from_list(pack_infos)

@@ -228,6 +233,62 @@ def get_pack_chunk_infos(
     return pack_infos


+def get_pack_infos_by_expand_soft_split(
+    inds: list[int],
+    dataset_id: int,
+    num_tokens: np.ndarray,
+    pack_max_length: int,
+    pack_workers: int = 8,
+    pack_chunk_size: int = 10000,
+    flash_attn_block_size: int = 128,
+    pack_len_type: str = "total_block",
+    pack_extra_buffer_size: int = 1000,
+):
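+    """Pack shuffled indices chunk by chunk via ``get_pack_chunk_infos``.
+
+    With ``pack_workers <= 1`` the chunks are processed serially in this process.
+    Otherwise ``num_tokens`` is staged once in POSIX shared memory and the chunks
+    are spread over a forked ``ProcessPoolExecutor``; workers get the block's
+    name/shape/dtype instead of a pickled copy of the array.
+    """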
+    if pack_workers <= 1:
+        pack_infos = []
+        for i in range(0, len(inds), pack_chunk_size):
+            chunk_inds = inds[i : i + pack_chunk_size]
+            chunk_pack_infos = get_pack_chunk_infos(
+                chunk_inds,
+                dataset_id,
+                pack_max_length,
+                flash_attn_block_size,
+                pack_len_type,
+                pack_extra_buffer_size,
+                num_tokens,
+            )
+            pack_infos.extend(chunk_pack_infos)
+    else:
+        chunks_inds = [inds[i : i + pack_chunk_size] for i in range(0, len(inds), pack_chunk_size)]
+
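+        # Stage num_tokens once in shared memory so every worker can attach to
+        # it by name rather than receiving its own pickled copy.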
+        shm = shared_memory.SharedMemory(create=True, size=num_tokens.nbytes)
+        shm_array = np.ndarray(num_tokens.shape, dtype=num_tokens.dtype, buffer=shm.buf)
+        np.copyto(shm_array, num_tokens)
+
+        mp_context = multiprocessing.get_context("fork")
+        process_chunk_with_args = partial(
+            get_pack_chunk_infos,
+            dataset_id=dataset_id,
+            target=pack_max_length,
+            flash_attn_block_size=flash_attn_block_size,
+            pack_len_type=pack_len_type,
+            pack_extra_buffer_size=pack_extra_buffer_size,
+            shm_name=shm.name,
+            shape=num_tokens.shape,
+            dtype=num_tokens.dtype,
+        )
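+        # executor.map yields results in submission order, so the resulting
+        # pack_infos stay deterministic for a fixed shuffle seed.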
+        with ProcessPoolExecutor(max_workers=pack_workers, mp_context=mp_context) as executor:
+            results = list(tqdm(executor.map(process_chunk_with_args, chunks_inds)))
+
+        pack_infos = []
+        for result in results:
+            pack_infos.extend(result)
+
+        shm.close()
+        shm.unlink()
+    return pack_infos
+
+
 class ExpandSoftPackDataset(_LegacySoftPackDataset):
     def __init__(
         self,
@@ -259,65 +320,9 @@ def __init__(
             seed=seed,
         )

-    @staticmethod
-    def get_pack_infos_staticmethod(
-        inds: list[int],
-        dataset_id: int,
-        num_tokens: np.ndarray,
-        pack_max_length: int,
-        pack_workers: int,
-        pack_chunk_size: int,
-        flash_attn_block_size: int,
-        pack_len_type: str,
-        pack_extra_buffer_size: int,
-    ):
-        if pack_workers <= 1:
-            pack_infos = []
-            for i in range(0, len(inds), pack_chunk_size):
-                chunk_inds = inds[i : i + pack_chunk_size]
-                chunk_pack_infos = get_pack_chunk_infos(
-                    chunk_inds,
-                    dataset_id,
-                    pack_max_length,
-                    flash_attn_block_size,
-                    pack_len_type,
-                    pack_extra_buffer_size,
-                    num_tokens,
-                )
-                pack_infos.extend(chunk_pack_infos)
-        else:
-            chunks_inds = [inds[i : i + pack_chunk_size] for i in range(0, len(inds), pack_chunk_size)]
-
-            shm = shared_memory.SharedMemory(create=True, size=num_tokens.nbytes)
-            shm_array = np.ndarray(num_tokens.shape, dtype=num_tokens.dtype, buffer=shm.buf)
-            np.copyto(shm_array, num_tokens)
-
-            mp_context = multiprocessing.get_context("fork")
-            process_chunk_with_args = partial(
-                get_pack_chunk_infos,
-                dataset_id=dataset_id,
-                target=pack_max_length,
-                flash_attn_block_size=flash_attn_block_size,
-                pack_len_type=pack_len_type,
-                pack_extra_buffer_size=pack_extra_buffer_size,
-                shm_name=shm.name,
-                shape=num_tokens.shape,
-                dtype=num_tokens.dtype,
-            )
-            with ProcessPoolExecutor(max_workers=pack_workers, mp_context=mp_context) as executor:
-                results = list(tqdm(executor.map(process_chunk_with_args, chunks_inds)))
-
-            pack_infos = []
-            for result in results:
-                pack_infos.extend(result)
-
-            shm.close()
-            shm.unlink()
-        return pack_infos
-
     def get_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray):
         inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()
-        pack_infos = self.get_pack_infos_staticmethod(
+        pack_infos = get_pack_infos_by_expand_soft_split(
             inds,
             dataset_id,
             num_tokens,
@@ -408,6 +413,57 @@ def _hard_pack_chunk(
     return out


+def get_pack_infos_by_hard_split(
+    inds: list[int], dataset_id: int, num_tokens: np.ndarray, pack_max_length: int, pack_workers: int = 1
+):
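+    """Hard-pack the shuffled samples into windows of exactly pack_max_length tokens.
+
+    The cumulative lengths of the shuffled samples define one virtual token
+    stream; each of the ``num_packed_samples`` window indices is resolved to a
+    span of sample indices by ``_hard_pack_chunk_core`` against those cumulative
+    lengths. The trailing remainder shorter than ``pack_max_length`` is dropped,
+    since ``num_packed_samples`` floors the total-to-window ratio.
+    """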
+    # number of packed samples
+    shfl_inds = inds
+    num_packed_samples = int(num_tokens.sum() / pack_max_length)
+
+    # shuffled cumulative lengths with leading 0
+    shfl_lens: np.ndarray = np.take(num_tokens, shfl_inds)
+    shfl_cu_lens = np.cumsum(shfl_lens, dtype=np.int64)
+    shfl_cu_lens = np.insert(shfl_cu_lens, 0, 0).astype(np.int64, copy=False)
+
+    # contiguous int64 arrays for cu and inds, inherited read-only by forked workers
+    cu_arr = np.asarray(shfl_cu_lens, dtype=np.int64).reshape(-1)
+    inds_arr = np.asarray(shfl_inds, dtype=np.int64).reshape(-1)
+
+    # chunk tasks
+    chunk_size = 10000
+    i_all = list(range(num_packed_samples))
+    chunks = [i_all[i : i + chunk_size] for i in range(0, len(i_all), chunk_size)]
+
+    pack_infos_list = []
+
+    if pack_workers > 1:
+        # Use fork to inherit read-only arrays; no extra shared memory copy needed
+        mp_context = multiprocessing.get_context("fork")
+        fn = partial(
+            _hard_pack_chunk_core,
+            dataset_id=dataset_id,
+            pack_max_length=pack_max_length,
+            cu=cu_arr,
+            inds_arr=inds_arr,
+        )
+        with ProcessPoolExecutor(max_workers=pack_workers, mp_context=mp_context) as ex:
+            for res in tqdm(ex.map(fn, chunks), total=len(chunks)):
+                pack_infos_list.extend(res)
+    else:
+        # single-process path, reuse the same core
+        for i_chunk in tqdm(chunks, total=len(chunks)):
+            pack_infos_list.extend(
+                _hard_pack_chunk_core(
+                    i_chunk,
+                    dataset_id=dataset_id,
+                    pack_max_length=pack_max_length,
+                    cu=cu_arr,
+                    inds_arr=inds_arr,
+                )
+            )
+    return pack_infos_list
+
+
 class HardPackDataset(_LegacySoftPackDataset):
     def __init__(
         self, datasets, pack_max_length=2048, global_pack=False, seed: int | None = None, pack_workers: int = 1
@@ -420,63 +476,12 @@ def __init__(
             seed=seed,
         )

-    @staticmethod
-    def get_pack_infos_staticmethod(
-        inds: list, dataset_id: int, num_tokens: np.ndarray, pack_max_length: int, pack_workers: int
-    ):
-        # number of packed samples
-        shfl_inds = inds
-        num_packed_samples = int(num_tokens.sum() / pack_max_length)
-
-        # shuffled cumulative lengths with leading 0
-        shfl_lens: np.ndarray = np.take(num_tokens, shfl_inds)
-        shfl_cu_lens = np.cumsum(shfl_lens, dtype=np.int64)
-        shfl_cu_lens = np.insert(shfl_cu_lens, 0, 0).astype(np.int64, copy=False)
-
-        # shared memory for cu and inds
-        cu_arr = np.asarray(shfl_cu_lens, dtype=np.int64).reshape(-1)
-        inds_arr = np.asarray(shfl_inds, dtype=np.int64).reshape(-1)
-
-        # chunk tasks
-        chunk_size = 10000
-        i_all = list(range(num_packed_samples))
-        chunks = [i_all[i : i + chunk_size] for i in range(0, len(i_all), chunk_size)]
-
-        pack_infos_list = []
-
-        if pack_workers > 1:
-            # Use fork to inherit read-only arrays; no extra shared memory copy needed
-            mp_context = multiprocessing.get_context("fork")
-            fn = partial(
-                _hard_pack_chunk_core,
-                dataset_id=dataset_id,
-                pack_max_length=pack_max_length,
-                cu=cu_arr,
-                inds_arr=inds_arr,
-            )
-            with ProcessPoolExecutor(max_workers=pack_workers, mp_context=mp_context) as ex:
-                for res in tqdm(ex.map(fn, chunks), total=len(chunks)):
-                    pack_infos_list.extend(res)
-        else:
-            # single-process path, reuse the same core
-            for i_chunk in tqdm(chunks, total=len(chunks)):
-                pack_infos_list.extend(
-                    _hard_pack_chunk_core(
-                        i_chunk,
-                        dataset_id=dataset_id,
-                        pack_max_length=pack_max_length,
-                        cu=cu_arr,
-                        inds_arr=inds_arr,
-                    )
-                )
-        return pack_infos_list
-
     def get_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray):
         # shuffled indices
         inds = list(range(len(dataset)))
         self.random.shuffle(inds)

-        pack_infos_list = self.get_pack_infos_staticmethod(
+        pack_infos_list = get_pack_infos_by_hard_split(
             inds, dataset_id, num_tokens, pack_max_length=self.pack_max_length, pack_workers=self.pack_workers
         )

@@ -631,7 +636,7 @@ def get_hard_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.nd
         # shuffled indices
         inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()

-        pack_infos_list = HardPackDataset.get_pack_infos_staticmethod(
+        pack_infos_list = get_pack_infos_by_hard_split(
             inds, dataset_id, num_tokens, pack_max_length=self.pack_max_length, pack_workers=self.pack_workers
         )
         return pack_infos_list
@@ -640,7 +645,7 @@ def get_soft_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.nd
         # shuffled indices
         inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()

-        pack_infos_list = ExpandSoftPackDataset.get_pack_infos_staticmethod(
+        pack_infos_list = get_pack_infos_by_expand_soft_split(
             inds,
             dataset_id,
             num_tokens,