
Commit 083873c

[ckpt] fix: Add missing broadcast_model_weights_from_rank0 option for build_parallelize_model() (#548)
1 parent 01be667 commit 083873c

3 files changed: +9 -1 lines changed


docs/usage/support_new_models/guide_and_checklist.md

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 **TLDR:** VeOmni patches HuggingFace models at runtime to add FSDP, Sequence Parallelism (SP), Expert Parallelism (EP), and fused kernels. This guide walks you through the integration steps with checklists per model type. For worked examples, see:
 - [qwen3_vl_example.md](./qwen3_vl_example.md) — VLM + MoE (image/video, deepstack, EP)
 - [qwen3_omni_moe_example.md](./qwen3_omni_moe_example.md) — Omni-modal MoE (image/video/audio, talker)
-
+
 > **Scope note:** This guide currently targets the **transformers v4** integration/patchgen flow.
 > **TODO:** Add a dedicated **transformers v5** section, since modeling code patchgen requires a slightly different approach.


veomni/arguments/arguments_types.py

Lines changed: 7 additions & 0 deletions
@@ -467,6 +467,13 @@ def _validate_accelerator(self):
         )
         if acc.fsdp_config.fsdp_mode == "fsdp2":
             assert self.init_device == "meta", "Please use init_device: meta for FSDP2 training"
+        else:
+            if self.broadcast_model_weights_from_rank0:
+                logger.warning_rank0(
+                    "Ignoring train.broadcast_model_weights_from_rank0=True because it is only "
+                    "used with train.accelerator.fsdp_config.fsdp_mode='fsdp2'. "
+                    f"Received fsdp_mode={acc.fsdp_config.fsdp_mode!r}. Disable this flag or switch to fsdp2.",
+                )

     def _derive_batch_config(self):
         acc = self.accelerator
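
The validation added in this hunk can be approximated in isolation as follows. This is a minimal sketch, not the VeOmni implementation: `TrainArgsSketch`, its flattened fields, and the `"fsdp1"` mode value used in the usage note are hypothetical stand-ins, and the standard `logging` module replaces the project's `logger.warning_rank0`.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger("veomni_sketch")


@dataclass
class TrainArgsSketch:
    """Hypothetical, flattened stand-in for the real training arguments."""

    fsdp_mode: str = "fsdp2"
    init_device: str = "meta"
    broadcast_model_weights_from_rank0: bool = False

    def validate(self) -> bool:
        """Return True only when the broadcast flag will take effect."""
        if self.fsdp_mode == "fsdp2":
            # Same requirement as the assert in the hunk above.
            assert self.init_device == "meta", "Please use init_device: meta for FSDP2 training"
            return self.broadcast_model_weights_from_rank0
        if self.broadcast_model_weights_from_rank0:
            # Warn rather than raise, so a config with a stray flag still runs.
            logger.warning(
                "Ignoring broadcast_model_weights_from_rank0=True because it is only "
                "used with fsdp_mode='fsdp2'. Received fsdp_mode=%r.",
                self.fsdp_mode,
            )
        return False
```

Under these assumptions, `TrainArgsSketch(fsdp_mode="fsdp1", init_device="cuda", broadcast_model_weights_from_rank0=True).validate()` logs the warning and returns `False`, while the same flag with the default `fsdp2` settings returns `True`. Warning instead of asserting mirrors the choice made in the hunk: the flag is ignored, not treated as a fatal misconfiguration.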

veomni/trainer/base.py

Lines changed: 1 addition & 0 deletions
@@ -325,6 +325,7 @@ def _build_parallelized_model(self):
             ),
             enable_reentrant=args.train.gradient_checkpointing.enable_reentrant,
             enable_forward_prefetch=args.train.accelerator.fsdp_config.forward_prefetch,
+            broadcast_model_weights_from_rank0=args.train.broadcast_model_weights_from_rank0,
         )
         self.model.train()
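
The one-line change in this hunk forwards a config value that was previously dropped at the call site. The sketch below illustrates that class of bug under stated assumptions: `build_model_sketch`, `TrainConfigSketch`, and the `False` default for the keyword are hypothetical, since the real signature and default of `build_parallelize_model()` are not shown in this commit.

```python
from dataclasses import dataclass


@dataclass
class TrainConfigSketch:
    """Hypothetical config object; the user asked for the broadcast."""

    forward_prefetch: bool = True
    broadcast_model_weights_from_rank0: bool = True


def build_model_sketch(
    *,
    enable_forward_prefetch: bool = True,
    broadcast_model_weights_from_rank0: bool = False,  # assumed default
) -> dict:
    """Hypothetical builder that reports the options it actually received."""
    return {
        "enable_forward_prefetch": enable_forward_prefetch,
        "broadcast_model_weights_from_rank0": broadcast_model_weights_from_rank0,
    }


def build_before_fix(cfg: TrainConfigSketch) -> dict:
    # The flag is never forwarded, so the builder silently uses its default.
    return build_model_sketch(enable_forward_prefetch=cfg.forward_prefetch)


def build_after_fix(cfg: TrainConfigSketch) -> dict:
    # The flag is forwarded explicitly, as in the added line above.
    return build_model_sketch(
        enable_forward_prefetch=cfg.forward_prefetch,
        broadcast_model_weights_from_rank0=cfg.broadcast_model_weights_from_rank0,
    )
```

With the same `TrainConfigSketch()`, `build_before_fix` reports the flag as `False` and `build_after_fix` reports it as `True`. No error is raised in either case, which is why an omitted keyword argument like this can go unnoticed.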
