@@ -133,7 +133,7 @@ def main():
133133 args .train .compute_train_steps (args .data .max_seq_len , args .data .train_size )
134134 train_dataloader = build_dataloader (
135135 dataset = train_dataset ,
136- dataloader_type = "streaming " ,
136+ dataloader_type = "native " ,
137137 micro_batch_size = args .train .micro_batch_size ,
138138 global_batch_size = args .train .global_batch_size ,
139139 dataloader_batch_size = args .train .dataloader_batch_size ,
@@ -142,17 +142,14 @@ def main():
142142 rmpad = args .train .rmpad ,
143143 rmpad_with_pos_ids = args .train .rmpad_with_pos_ids ,
144144 bsz_warmup_ratio = args .train .bsz_warmup_ratio ,
145- dyn_bsz_runtime = args .train .dyn_bsz_runtime ,
145+ bsz_warmup_init_mbtoken = args .train .bsz_warmup_init_mbtoken ,
146146 dyn_bsz_margin = args .train .dyn_bsz_margin ,
147147 dyn_bsz_buffer_size = args .train .dyn_bsz_buffer_size ,
148148 collate_fn = None ,
149- bsz_warmup_init_mbtoken = args .train .bsz_warmup_init_mbtoken ,
150- infinity = True ,
151149 num_workers = args .data .num_workers ,
152150 drop_last = args .data .drop_last ,
153151 pin_memory = args .data .pin_memory ,
154152 prefetch_factor = args .data .prefetch_factor ,
155- drop_resume_buffer = args .data .drop_resume_buffer ,
156153 )
157154
158155 logger .info_rank0 ("Prepare model" )
@@ -351,7 +348,7 @@ def test_trainer_saveload_ep8():
351348 "--nnodes=1" ,
352349 "--nproc_per_node=8" ,
353350 "--master_port=4321" ,
354- "tests/utils /test_trainer_saveload.py" ,
351+ "tests/checkpoints /test_trainer_saveload.py" ,
355352 "tests/checkpoints/ep8.yaml" ,
356353 ]
357354 ep8_result = subprocess .run (ep8_command , check = True )
0 commit comments