Skip to content

Commit 3f1ae7b

Browse files
authored
[ci] fix: dataloader in e2e ckpt test (#233)
1 parent 7e1149b commit 3f1ae7b

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

tests/checkpoints/test_trainer_saveload.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def main():
133133
args.train.compute_train_steps(args.data.max_seq_len, args.data.train_size)
134134
train_dataloader = build_dataloader(
135135
dataset=train_dataset,
136-
dataloader_type="streaming",
136+
dataloader_type="native",
137137
micro_batch_size=args.train.micro_batch_size,
138138
global_batch_size=args.train.global_batch_size,
139139
dataloader_batch_size=args.train.dataloader_batch_size,
@@ -142,17 +142,14 @@ def main():
142142
rmpad=args.train.rmpad,
143143
rmpad_with_pos_ids=args.train.rmpad_with_pos_ids,
144144
bsz_warmup_ratio=args.train.bsz_warmup_ratio,
145-
dyn_bsz_runtime=args.train.dyn_bsz_runtime,
145+
bsz_warmup_init_mbtoken=args.train.bsz_warmup_init_mbtoken,
146146
dyn_bsz_margin=args.train.dyn_bsz_margin,
147147
dyn_bsz_buffer_size=args.train.dyn_bsz_buffer_size,
148148
collate_fn=None,
149-
bsz_warmup_init_mbtoken=args.train.bsz_warmup_init_mbtoken,
150-
infinity=True,
151149
num_workers=args.data.num_workers,
152150
drop_last=args.data.drop_last,
153151
pin_memory=args.data.pin_memory,
154152
prefetch_factor=args.data.prefetch_factor,
155-
drop_resume_buffer=args.data.drop_resume_buffer,
156153
)
157154

158155
logger.info_rank0("Prepare model")
@@ -351,7 +348,7 @@ def test_trainer_saveload_ep8():
351348
"--nnodes=1",
352349
"--nproc_per_node=8",
353350
"--master_port=4321",
354-
"tests/utils/test_trainer_saveload.py",
351+
"tests/checkpoints/test_trainer_saveload.py",
355352
"tests/checkpoints/ep8.yaml",
356353
]
357354
ep8_result = subprocess.run(ep8_command, check=True)

veomni/checkpoint/dcp_checkpointer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,7 @@ def get_state_dict_with_ep_dim_preprocess(self, state_dict, action):
175175
continue
176176

177177
# each tensor in the state dict should only belong to one EP entry
178-
if len(matches) > 1:
179-
raise RuntimeError(f"Ambiguous EP spec match for state key '{name}': {matches}")
178+
assert len(matches) == 1, f"Ambiguous EP spec match for state key '{name}': {matches}"
180179

181180
ep_key = matches[0]
182181
cur_spec_info = ep_fqn2spec_info[ep_key]

0 commit comments

Comments
 (0)