@@ -274,10 +274,6 @@ def freeze_moe_router(megatron_model):
         mixed_precision_wrapper = CustomFloat16Module
         pre_wrap_hook.extend([freeze_moe_router])
 
-    # If deferring fp32 logits, disable mixed-precision wrapper entirely
-    if policy_cfg["megatron_cfg"].get("defer_fp32_logits", None):
-        mixed_precision_wrapper = None
-
     # Model, optimizer, and learning rate.
     model = get_model(
         cfg.model,
@@ -663,6 +659,9 @@ def __init__(
             assert self.cfg["megatron_cfg"]["defer_fp32_logits"], (
                 "defer_fp32_logits must be True if logprob_chunk_size is set"
             )
+        self.defer_fp32_logits = self.cfg["megatron_cfg"].get(
+            "defer_fp32_logits", None
+        ) and (model_cfg.fp16 or model_cfg.bf16)
 
         checkpoint_config = CheckpointConfig(
             save_interval=100,
@@ -796,8 +795,6 @@ def __init__(
             ref_mixed_precision_wrapper = Float16Module
             if self.cfg["megatron_cfg"].get("freeze_moe_router", False):
                 ref_mixed_precision_wrapper = CustomFloat16Module
-            if self.cfg["megatron_cfg"].get("defer_fp32_logits", None):
-                ref_mixed_precision_wrapper = None
 
             reference_model = get_model(
                 self.megatron_cfg.model,
@@ -1068,6 +1065,7 @@ def train(
                 pad_individual_seqs_to_multiple_of=pad_factor,
                 pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of,
                 pad_full_seq_to=pad_full_seq_to,
+                defer_fp32_logits=self.defer_fp32_logits,
             ),
             data_iterator=data_iterator,
             model=self.model,
@@ -1284,6 +1282,9 @@ def forward_step_fn(
         if packed_seq_params is not None:
             additional_kwargs["packed_seq_params"] = packed_seq_params
 
+        if self.defer_fp32_logits:
+            additional_kwargs["fp32_output"] = False
+
         output_tensor = model(
             input_ids=input_ids_cp_sharded,
             position_ids=position_ids,
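
The change above keeps the model's bf16/fp16 logits in half precision (`fp32_output=False`) and defers the fp32 upcast to the log-probability computation, optionally chunked via `logprob_chunk_size`. Below is a minimal, standalone sketch of that pattern; the function name `chunked_token_logprobs` and its `chunk_size` argument are illustrative assumptions, not this repository's actual API.

```python
import torch
import torch.nn.functional as F


def chunked_token_logprobs(
    logits: torch.Tensor,      # [batch, seq, vocab], bf16/fp16 (fp32 upcast deferred)
    target_ids: torch.Tensor,  # [batch, seq]
    chunk_size: int = 1024,    # stands in for logprob_chunk_size
) -> torch.Tensor:
    """Compute per-token log-probs, upcasting to fp32 one sequence chunk at a time."""
    out = torch.empty(target_ids.shape, dtype=torch.float32, device=logits.device)
    seq_len = logits.shape[1]
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        chunk = logits[:, start:end].float()       # fp32 only for this chunk
        logprobs = F.log_softmax(chunk, dim=-1)
        out[:, start:end] = logprobs.gather(
            -1, target_ids[:, start:end].unsqueeze(-1)
        ).squeeze(-1)
    return out
```

Upcasting one sequence chunk at a time avoids materializing a full fp32 [batch, seq, vocab] logits tensor, which is the memory saving that motivates deferring the cast.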