Describe the bug
During mixed-precision training (BF16 & FP32), a RuntimeError (size mismatch) is raised when saving a checkpoint via `get_parameter_state_dp_zero`.
Based on our analysis, the hardcoded parameter reordering in `_build_model_and_main_param_groups` within `DistributedOptimizer` causes `self.model_param_group_index_map` to fall out of sync with the actual `optimizer.param_groups`.
Steps/Code to reproduce bug
- Before building the DDP model and optimizer, use `Float16Module` to wrap the model for BF16 training, but manually promote certain modules (both params and inputs) to FP32 (see the condensed sketch after this list).
- Train for several steps.
- Call `get_parameter_state_dp_zero` in `DistributedOptimizer` to collect optimizer states, which triggers the size-mismatch error.
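For context, a condensed PyTorch-only sketch of the condition this setup creates (BF16 and FP32 params feeding the same optimizer); `Block` and its submodules are hypothetical stand-ins, not our actual model:

```python
import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = torch.nn.Linear(16, 16)  # stays bf16 after the cast below
        self.norm = torch.nn.LayerNorm(16)   # manually promoted back to fp32

model = Block().to(torch.bfloat16)  # analogous to wrapping with Float16Module
model.norm.float()                  # the manual FP32 promotion from step 1

print({p.dtype for p in model.parameters()})
# {torch.bfloat16, torch.float32}: one param group now mixes dtypes, which is
# what exercises the fp32 / fp16-shard reordering path in DistributedOptimizer.
```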
Root Cause Analysis
- Initial Map Construction: In `__init__`, `self.model_param_group_index_map` is first constructed via `_build_optimizer_group_ranges`. This map records the position `(group_index, group_order)` of each param in `param_groups`.
- Hardcoded Reordering: Subsequently, `_build_model_and_main_param_groups` reorders the parameters within each group (placing native FP32 shards at the front and main parameter shards converted from FP16/BF16 at the back) and updates `optimizer.param_groups` accordingly.
- Index Invalidation: The `model_param_group_index_map` is not updated after this reordering. Consequently, downstream functions like `_get_main_param_and_optimizer_states` retrieve the wrong tensors using stale `group_order` indices, leading to shape mismatches during buffer copy operations (see the minimal sketch after this list).
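To make the stale-index failure concrete, here is a minimal, self-contained sketch (plain PyTorch, not Megatron code) of a map built before the reordering and then used after it:

```python
import torch

# Two params of different shapes standing in for a native fp32 param and a
# bf16 param whose fp32 main-param shard replaces it in the group.
fp32_param = torch.nn.Parameter(torch.zeros(10))
bf16_param = torch.nn.Parameter(torch.zeros(20, dtype=torch.bfloat16))

# Discovery order, as recorded by the map built in __init__.
param_groups = [{"params": [bf16_param, fp32_param]}]
index_map = {
    id(p): (gi, go)
    for gi, group in enumerate(param_groups)
    for go, p in enumerate(group["params"])
}

# Reordering analogous to _build_model_and_main_param_groups: fp32 params
# first, (main shards of) bf16 params after; the map is NOT updated.
param_groups[0]["params"] = [fp32_param, bf16_param]

# A lookup through the stale map now returns the wrong tensor.
gi, go = index_map[id(bf16_param)]          # still (0, 0)
looked_up = param_groups[gi]["params"][go]  # fp32_param, shape [10]

try:
    torch.empty(20).copy_(looked_up)        # buffer sized for bf16_param
except RuntimeError as err:
    print("size mismatch, as seen in get_parameter_state_dp_zero:", err)
```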
Additional questions
- What is the design motivation behind this specific reordering (grouping by dtype)?
- What is the recommended best practice to fix this: disabling the reordering to maintain discovery-order consistency, or explicitly updating the `model_param_group_index_map` after the reordering is performed (a rough sketch of the latter follows)?
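A hedged sketch of the second option, not the actual Megatron-LM fix: rebuild the map from the reordered `optimizer.param_groups` instead of keeping the pre-reordering indices. `shard_to_model_param` here is a hypothetical reverse mapping from main-param shards back to the model params they were created from; `DistributedOptimizer` would need to supply the real one.

```python
def rebuild_model_param_group_index_map(optimizer, shard_to_model_param):
    # Walk the (already reordered) param groups and record each param's
    # current (group_index, group_order) position.
    index_map = {}
    for group_index, group in enumerate(optimizer.param_groups):
        for group_order, param in enumerate(group["params"]):
            # Native fp32 params are their own "model param"; fp32 main-param
            # shards map back to the fp16/bf16 model param they shadow.
            model_param = shard_to_model_param.get(param, param)
            index_map[model_param] = (group_index, group_order)
    return index_map
```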