@@ -148,7 +148,7 @@ def main():
148
148
model_config .fuse_attention_ffn = model_args .fuse_attention_ffn
149
149
model_config .pp_seg_method = training_args .pp_seg_method
150
150
model_config .seq_length = data_args .max_length
151
- model_config .max_sequence_length = training_args .max_seq_length
151
+ model_config .max_sequence_length = training_args .max_seq_len
152
152
model_config .num_nextn_predict_layers = model_args .num_nextn_predict_layers
153
153
logger .info (f"Final model config: { model_config } " )
154
154
logger .info ("Creating model" )
@@ -213,11 +213,11 @@ def neft_post_hook(module, input, output):
213
213
214
214
dataset_config = {
215
215
"tokenizer" : tokenizer ,
216
- "max_seq_len" : training_args .max_seq_length ,
216
+ "max_seq_len" : training_args .max_seq_len ,
217
217
"random_seed" : training_args .seed ,
218
- "num_replicas" : 1 ,
219
- "rank" : 0 ,
220
- "num_samples_each_epoch" : 6000000 ,
218
+ "num_replicas" : training_args . dataset_world_size ,
219
+ "rank" : training_args . dataset_rank ,
220
+ "num_samples_each_epoch" : data_args . num_samples_each_epoch ,
221
221
"random_shuffle" : data_args .random_shuffle ,
222
222
"greedy_intokens" : data_args .greedy_intokens ,
223
223
"packing" : data_args .packing ,
@@ -251,7 +251,7 @@ def neft_post_hook(module, input, output):
251
251
collate_fn ,
252
252
tokenizer = tokenizer ,
253
253
model_args = model_args ,
254
- max_seq_len = training_args .max_seq_length + model_config .num_nextn_predict_layers ,
254
+ max_seq_len = training_args .max_seq_len + model_config .num_nextn_predict_layers ,
255
255
)
256
256
trainer = SFTTrainer (
257
257
model = model ,
0 commit comments