
Commit a0755eb (parent: 40e7040)

fix: Use Float16Module even when defer_fp32_logits=True (#1537)

Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>

2 files changed: +14 -6 lines


nemo_rl/models/megatron/common.py

Lines changed: 7 additions & 0 deletions
@@ -358,6 +358,7 @@ def forward_step_arbitrary_loss(
     pad_individual_seqs_to_multiple_of: int = 1,
     pad_packed_seq_to_multiple_of: int = 1,
     pad_full_seq_to: Optional[int] = None,
+    defer_fp32_logits: Optional[bool] = None,
     cp_normalize: bool = True,
     policy_cfg: Optional[dict] = None,
 ):
@@ -372,6 +373,9 @@ def forward_step_arbitrary_loss(
         loss_fn (LossFunction): Loss function to apply
         pack_sequences (bool): Whether to pack sequences for efficiency
         seq_length_key (Optional[str]): Key in data_dict containing actual sequence lengths
+        pad_individual_seqs_to_multiple_of (int): Pad individual sequences to a multiple of this value
+        pad_full_seq_to (Optional[int]): Pad packed sequences to this value
+        defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
         cp_normalize (bool): Whether to normalize the loss by the cp_size
         policy_cfg (Optional[dict]): Policy configuration containing generation parameters

@@ -453,6 +457,9 @@ def forward_step_arbitrary_loss(
     if packed_seq_params is not None:
         additional_kwargs["packed_seq_params"] = packed_seq_params

+    if defer_fp32_logits:
+        additional_kwargs["fp32_output"] = False
+
     with straggler_timer:
         output_tensor = model(
             input_ids=input_ids_cp_sharded,
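The new defer_fp32_logits flag asks the model not to upcast its output logits to fp32 inside the forward pass (fp32_output=False); the caller can then upcast lazily, for example one chunk at a time while computing token log-probs, which is why the worker ties the flag to logprob_chunk_size. The sketch below is not from the repository and uses made-up names (chunked_token_logprobs, chunk_size); it only illustrates why deferring the cast saves memory.

# Minimal sketch (not from the repo): compute per-token log-probs from
# half-precision logits, upcasting to fp32 one chunk at a time instead of
# materializing the full [batch, seq, vocab] tensor in fp32.
import torch

def chunked_token_logprobs(logits_half, token_ids, chunk_size=1024):
    # logits_half: [batch, seq, vocab] in bf16/fp16; token_ids: [batch, seq] (long)
    out = []
    for start in range(0, logits_half.shape[1], chunk_size):
        chunk = logits_half[:, start : start + chunk_size].float()  # fp32 only for this slice
        logprobs = torch.log_softmax(chunk, dim=-1)
        ids = token_ids[:, start : start + chunk_size].unsqueeze(-1)
        out.append(logprobs.gather(dim=-1, index=ids).squeeze(-1))
    return torch.cat(out, dim=1)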

nemo_rl/models/policy/megatron_policy_worker.py

Lines changed: 7 additions & 6 deletions
@@ -274,10 +274,6 @@ def freeze_moe_router(megatron_model):
         mixed_precision_wrapper = CustomFloat16Module
         pre_wrap_hook.extend([freeze_moe_router])

-    # If deferring fp32 logits, disable mixed-precision wrapper entirely
-    if policy_cfg["megatron_cfg"].get("defer_fp32_logits", None):
-        mixed_precision_wrapper = None
-
     # Model, optimizer, and learning rate.
     model = get_model(
         cfg.model,
@@ -663,6 +659,9 @@ def __init__(
             assert self.cfg["megatron_cfg"]["defer_fp32_logits"], (
                 "defer_fp32_logits must be True if logprob_chunk_size is set"
             )
+        self.defer_fp32_logits = self.cfg["megatron_cfg"].get(
+            "defer_fp32_logits", None
+        ) and (model_cfg.fp16 or model_cfg.bf16)

         checkpoint_config = CheckpointConfig(
             save_interval=100,
@@ -796,8 +795,6 @@ def __init__(
             ref_mixed_precision_wrapper = Float16Module
             if self.cfg["megatron_cfg"].get("freeze_moe_router", False):
                 ref_mixed_precision_wrapper = CustomFloat16Module
-            if self.cfg["megatron_cfg"].get("defer_fp32_logits", None):
-                ref_mixed_precision_wrapper = None

             reference_model = get_model(
                 self.megatron_cfg.model,
@@ -1068,6 +1065,7 @@ def train(
                     pad_individual_seqs_to_multiple_of=pad_factor,
                     pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of,
                     pad_full_seq_to=pad_full_seq_to,
+                    defer_fp32_logits=self.defer_fp32_logits,
                 ),
                 data_iterator=data_iterator,
                 model=self.model,
@@ -1284,6 +1282,9 @@ def forward_step_fn(
             if packed_seq_params is not None:
                 additional_kwargs["packed_seq_params"] = packed_seq_params

+            if self.defer_fp32_logits:
+                additional_kwargs["fp32_output"] = False
+
             output_tensor = model(
                 input_ids=input_ids_cp_sharded,
                 position_ids=position_ids,
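Taken together, the worker-side changes keep the Float16Module (or CustomFloat16Module) wrapper even when the fp32 cast is deferred; only the effective flag decides whether fp32_output=False is passed to the model. Below is a small sketch of that guard logic, with placeholder values (megatron_cfg, model_is_half_precision) standing in for the real config objects.

# Sketch of the guard logic from the diff; megatron_cfg and
# model_is_half_precision are placeholder stand-ins for the real
# self.cfg["megatron_cfg"] and model_cfg.fp16 / model_cfg.bf16.
megatron_cfg = {"defer_fp32_logits": True, "logprob_chunk_size": 512}
model_is_half_precision = True

if megatron_cfg.get("logprob_chunk_size") is not None:
    assert megatron_cfg["defer_fp32_logits"], (
        "defer_fp32_logits must be True if logprob_chunk_size is set"
    )

# Effective flag: only defer the cast when the model actually runs in half precision.
defer_fp32_logits = bool(megatron_cfg.get("defer_fp32_logits")) and model_is_half_precision

additional_kwargs = {}
if defer_fp32_logits:
    additional_kwargs["fp32_output"] = False  # logits stay in bf16/fp16; upcast later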
