@@ -33,6 +33,7 @@ def tokenize_dataset(
3333 max_seq_length : int ,
3434 seed : int ,
3535 dataset_kwargs : dict | None = None ,
36+ pad_seq_to_mult : int | None = 1 ,
3637):
3738 """
3839 Tokenizes a dataset from the provided path using the specified tokenizer
@@ -45,6 +46,8 @@ def tokenize_dataset(
4546 seed (int): Random seed for shuffling the dataset.
4647 dataset_kwargs (dict | None): Additional keyword arguments to pass to create_sft_dataset.
4748 Can include 'chat', 'use_hf_tokenizer_chat_template', 'tool_schemas', etc.
49+ pad_seq_to_mult (int | None): Optional multiple to pad each sequence to during packing
50+ preparation (e.g., set to 2 * context_parallel_size for THD CP).
4851
4952 Returns:
5053 np.ndarray: A NumPy array containing the tokenized data.
@@ -66,15 +69,56 @@ def tokenize_dataset(
6669 if hasattr (tokenizer , "_tokenizer" ):
6770 tokenizer ._tokenizer .chat_template = chat_template
6871
72+ if pad_seq_to_mult is not None and pad_seq_to_mult <= 0 :
73+ raise ValueError ("pad_seq_to_mult must be a positive integer when provided." )
74+
 75+ # Treat None as 1 (i.e. no extra padding); values > 1 enable padding to that multiple.
76+ pad_seq_length_to_mult = 1 if pad_seq_to_mult is None else max (1 , pad_seq_to_mult )
77+
6978 dataset = create_sft_dataset (
7079 path = path ,
7180 tokenizer = tokenizer ,
7281 seq_length = max_seq_length ,
7382 seed = seed ,
7483 is_test = True ,
84+ pad_seq_length_to_mult = pad_seq_length_to_mult ,
7585 ** dataset_kwargs ,
7686 )
77- return np .array ([dataset [i ] for i in range (len (dataset ))])
87+
88+ pad_id = dataset .tokenizer .eod
89+ pad_seq_length_to_mult = dataset .pad_seq_length_to_mult
90+ max_seq_length = dataset .max_seq_length
91+ dataset = np .array ([dataset [i ] for i in range (len (dataset ))])
92+
 93+ if pad_seq_length_to_mult > 1 :
94+
95+ def pre_pad_dataset (data , max_seq_length , max_length_to_pad , pad_id ):
96+ """
97+ Pad each individual data point to the length of max_length_to_pad.
98+ This keeps packed samples divisible by the requested multiple (used for CP/THD).
99+ """
100+ assert max_seq_length >= max_length_to_pad
101+ for key , val in data .items ():
102+ if key in {"input_ids" , "context_ids" }:
103+ if len (val ) <= max_length_to_pad :
104+ # input_ids are truncated by 1 for labels; add 1 extra pad token
105+ val = val + [pad_id ] * (max_length_to_pad - len (val ) + 1 )
106+ elif len (val ) > max_seq_length :
107+ logging .info (
108+ "Sequence length %d is larger than max_seq_length %d; truncating for packing." ,
109+ len (val ),
110+ max_seq_length ,
111+ )
112+ val = val [:max_seq_length ]
113+ data [key ] = val
114+ return
115+
116+ ceil_to_nearest = lambda n , m : (n + m - 1 ) // m * m
117+ for data in dataset :
118+ max_length_to_pad = min (max_seq_length , ceil_to_nearest (len (data ["input_ids" ]), pad_seq_length_to_mult ))
119+ pre_pad_dataset (data , max_seq_length , max_length_to_pad , pad_id )
120+
121+ return dataset
78122
79123
80124def prepare_packed_sequence_data (
@@ -87,6 +131,7 @@ def prepare_packed_sequence_data(
87131 seed : int | None = 0 ,
88132 packing_algorithm : str = "first_fit_shuffle" ,
89133 dataset_kwargs : dict | None = None ,
134+ pad_seq_to_mult : int | None = 1 ,
90135):
91136 """
92137 Prepares a packed sequence dataset from a given input file and saves it to an output file.
@@ -103,12 +148,21 @@ def prepare_packed_sequence_data(
103148 currently supports "first_fit_shuffle" and "first_fit_decreasing".
104149 dataset_kwargs (dict | None): Additional keyword arguments to pass to create_sft_dataset.
105150 Enables packing with chat templates, tool schemas, etc.
151+ pad_seq_to_mult (int | None): Optional multiple to pad each sequence to during packing
152+ preparation (e.g., set to 2 * context_parallel_size for THD CP).
106153
107154 Returns:
108155 None: Saves the packed sequence data to the specified output path.
109156 """
110157 logger .info (f"Preparing packed sequence from { input_path } " )
111- dataset = tokenize_dataset (input_path , tokenizer , max_seq_length , seed , dataset_kwargs )
158+ dataset = tokenize_dataset (
159+ input_path ,
160+ tokenizer ,
161+ max_seq_length ,
162+ seed ,
163+ dataset_kwargs ,
164+ pad_seq_to_mult = pad_seq_to_mult ,
165+ )
112166 sequences , histogram = create_hist (dataset , max_seq_length )
113167
114168 assignments , packing_metadata = create_packing_strategy (histogram , packed_sequence_size , packing_algorithm )
@@ -185,6 +239,11 @@ class PackedSequenceSpecs:
185239 """
186240 If True, pad cu_seqlens to a constant size, which is required for use with cudagraphs.
187241 """
242+ pad_seq_to_mult : int | None = 1
243+ """
244+ Optional multiple to pad each sample to when generating packed datasets.
245+ For THD/context parallel, set to (context_parallel_size * 2) to keep samples divisible.
246+ """
188247
189248 def __post_init__ (self ):
190249 if self .packed_train_data_path is not None :
@@ -212,3 +271,6 @@ def __post_init__(self):
212271 assert self .packed_val_data_path .exists (), (
213272 f"packed validation data file does not exist: { self .packed_val_data_path } "
214273 )
274+
275+ if self .pad_seq_to_mult is not None and self .pad_seq_to_mult <= 0 :
276+ raise ValueError ("pad_seq_to_mult must be a positive integer when provided." )
0 commit comments