Skip to content

Commit 76064bd

Browse files
committed
update megatron
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent f44e731 commit 76064bd

File tree

6 files changed: +89 additions, -34 deletions

nemo_rl/models/automodel/train.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -547,14 +547,15 @@ def __call__(
547547
logits, self.device_mesh, self.cp_mesh, sequence_dim
548548
)
549549

550-
# Wrap loss function for sequence packing if needed
551-
wrapped_prepare_loss_input = partial(
550+
# Wrap prepare_loss_input with sampling_params
551+
prepare_loss_input_wrapped = partial(
552552
prepare_loss_input, sampling_params=self.sampling_params
553553
)
554+
# Wrap loss function for sequence packing if needed
554555
if self.enable_seq_packing:
555556
loss_fn = SequencePackingLossWrapper(
556557
loss_fn=self.loss_fn,
557-
prepare_fn=wrapped_prepare_loss_input,
558+
prepare_fn=prepare_loss_input_wrapped,
558559
cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q,
559560
cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q,
560561
)
@@ -565,7 +566,7 @@ def __call__(
565566
global_valid_toks,
566567
)
567568
else:
568-
loss_input, mb = wrapped_prepare_loss_input(logits, mb, self.loss_fn)
569+
loss_input, mb = prepare_loss_input_wrapped(logits, mb, self.loss_fn)
569570
loss, loss_metrics = self.loss_fn(
570571
data=mb,
571572
global_valid_seqs=global_valid_seqs,

nemo_rl/models/megatron/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
2222
from megatron.core.transformer import MegatronModule
2323

24+
from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams
25+
2426

2527
class MegatronGenerationConfig(TypedDict):
2628
# Total GPU memory (in GB) allocated for KV cache buffers
@@ -55,6 +57,7 @@ class RuntimeConfig(NamedTuple):
5557
optimizer_cpu_offload: bool
5658
offload_optimizer_for_logprob: bool
5759
is_generation_colocated: Optional[bool]
60+
sampling_params: Optional[TrainingSamplingParams]
5861
final_padded_vocab_size: int
5962

6063

nemo_rl/models/megatron/setup.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
except ImportError:
6767
HAVE_FSDP2 = False
6868

69+
from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams
6970
from nemo_rl.distributed.named_sharding import NamedSharding
7071
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
7172
from nemo_rl.models.megatron.config import ModelAndOptimizerState, RuntimeConfig
@@ -194,7 +195,6 @@ def validate_and_set_config(
194195
hf_model_name,
195196
pretrained_path,
196197
weights_path,
197-
tokenizer,
198198
):
199199
# Handle generation colocation
200200
is_generation_colocated = None
@@ -218,6 +218,16 @@ def validate_and_set_config(
218218
optimizer_cpu_offload = config["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"]
219219
offload_optimizer_for_logprob = config["offload_optimizer_for_logprob"]
220220

221+
# Sampling parameters configuration
222+
sampling_params = None
223+
if "generation" in config and config["generation"] is not None:
224+
generation_cfg = config["generation"]
225+
sampling_params = TrainingSamplingParams(
226+
top_k=generation_cfg.get("top_k", None),
227+
top_p=generation_cfg.get("top_p", 1.0),
228+
temperature=generation_cfg.get("temperature", 1.0),
229+
)
230+
221231
# Reward models are not yet supported with Megatron.
222232
if "reward_model_cfg" in config and config["reward_model_cfg"]["enabled"]:
223233
raise NotImplementedError(
@@ -242,6 +252,7 @@ def validate_and_set_config(
242252
optimizer_cpu_offload,
243253
offload_optimizer_for_logprob,
244254
is_generation_colocated,
255+
sampling_params,
245256
final_padded_vocab_size,
246257
)
247258

nemo_rl/models/megatron/train.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from megatron.core.pipeline_parallel import get_forward_backward_func
3030
from megatron.core.utils import StragglerDetector
3131

32+
from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams
3233
from nemo_rl.algorithms.loss import (
3334
SequencePackingLossWrapper,
3435
prepare_loss_input,
@@ -56,7 +57,6 @@
5657
def model_forward(
5758
model: GPTModel,
5859
data_dict: BatchedDataDict[Any],
59-
cfg: PolicyConfig,
6060
input_ids_cp_sharded: torch.Tensor,
6161
position_ids: torch.Tensor,
6262
attention_mask: torch.Tensor,
@@ -106,27 +106,26 @@ def model_forward(
106106

107107

108108
def apply_temperature_scaling(
109-
logits: torch.Tensor,
110-
cfg: PolicyConfig,
109+
logits: torch.Tensor, sampling_params: Optional[TrainingSamplingParams]
111110
) -> torch.Tensor:
112111
"""Apply temperature scaling to logits.
113112
114113
Args:
115114
logits: Logits tensor to scale
116-
cfg: Policy configuration containing generation settings
115+
sampling_params: Sampling parameters
117116
118117
Returns:
119118
torch.Tensor: Temperature-scaled logits
120119
"""
121-
if "generation" in cfg and cfg["generation"] is not None:
122-
logits.div_(cfg["generation"]["temperature"])
120+
if sampling_params is not None and sampling_params.temperature != 1.0:
121+
logits.div_(sampling_params.temperature)
123122
return logits
124123

125124

126125
def forward_with_post_processing_fn(
127126
data_iterator: Iterator[ProcessedMicrobatch],
128127
model: GPTModel,
129-
cfg: PolicyConfig,
128+
sampling_params: TrainingSamplingParams,
130129
post_processing_fn: PostProcessingFunction,
131130
defer_fp32_logits: Optional[bool] = False,
132131
global_valid_seqs: Optional[torch.Tensor] = None,
@@ -142,7 +141,7 @@ def forward_with_post_processing_fn(
142141
Args:
143142
data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
144143
model: The model to run forward pass on
145-
cfg: Policy configuration dictionary
144+
sampling_params: Sampling parameters
146145
post_processing_fn: Post-processing function to post-process the logits
147146
defer_fp32_logits: Whether to defer FP32 conversion of logits
148147
global_valid_seqs: Global valid sequence count for loss normalization
@@ -169,7 +168,6 @@ def forward_with_post_processing_fn(
169168
output_tensor = model_forward(
170169
model=model,
171170
data_dict=data_dict,
172-
cfg=cfg,
173171
input_ids_cp_sharded=input_ids_cp_sharded,
174172
position_ids=position_ids,
175173
attention_mask=attention_mask,
@@ -187,7 +185,7 @@ def forward_with_post_processing_fn(
187185
# Temperature scaling is element-wise, directly applying it here.
188186
# Other sampling parameters like top-k and top-p need the logits from whole vocabulary,
189187
# so applying them when gathering logits from vocab parallel (called in LossPostProcessor and LogprobsPostProcessor).
190-
apply_temperature_scaling(output_tensor, cfg)
188+
apply_temperature_scaling(output_tensor, sampling_params)
191189

192190
# Use type checking to dispatch to the correct post-processing method
193191
if isinstance(post_processing_fn, LossPostProcessor):
@@ -218,7 +216,7 @@ def forward_with_post_processing_fn(
218216

219217
def megatron_forward_backward(
220218
model: GPTModel,
221-
cfg: PolicyConfig,
219+
sampling_params: TrainingSamplingParams,
222220
data_iterator: Iterator[ProcessedMicrobatch],
223221
num_microbatches: int,
224222
seq_length: int,
@@ -238,7 +236,7 @@ def megatron_forward_backward(
238236
239237
Args:
240238
model: The model to train
241-
cfg: Policy configuration dictionary
239+
sampling_params: Sampling parameters
242240
data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
243241
num_microbatches: Number of microbatches to process
244242
seq_length: Sequence length
@@ -255,7 +253,7 @@ def megatron_forward_backward(
255253
"""
256254
forward_step = partial(
257255
forward_with_post_processing_fn,
258-
cfg=cfg,
256+
sampling_params=sampling_params,
259257
post_processing_fn=post_processing_fn,
260258
defer_fp32_logits=defer_fp32_logits,
261259
global_valid_seqs=global_valid_seqs,
@@ -282,11 +280,13 @@ def __init__(
282280
cfg: PolicyConfig,
283281
num_microbatches: int = 1,
284282
cp_normalize: bool = True,
283+
sampling_params: Optional[TrainingSamplingParams] = None,
285284
):
286285
self.loss_fn = loss_fn
287286
self.cfg = cfg
288287
self.num_microbatches = num_microbatches
289288
self.cp_normalize = cp_normalize
289+
self.sampling_params = sampling_params
290290

291291
def __call__(
292292
self,
@@ -310,12 +310,17 @@ def __call__(
310310
Returns:
311311
Callable: Function that takes output tensor and returns (loss, metrics) tuple
312312
"""
313+
# wrap prepare_loss_input with sampling_params
314+
prepare_loss_input_wrapped = partial(
315+
prepare_loss_input, sampling_params=self.sampling_params
316+
)
317+
313318
# wrap loss function with loss input preparation
314319
pack_sequences = self.cfg["sequence_packing"]["enabled"]
315320
if pack_sequences and packed_seq_params is not None:
316321
loss_fn_wrapped = SequencePackingLossWrapper(
317322
loss_fn=self.loss_fn,
318-
prepare_fn=prepare_loss_input,
323+
prepare_fn=prepare_loss_input_wrapped,
319324
cu_seqlens_q=packed_seq_params.cu_seqlens_q,
320325
cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded,
321326
vocab_parallel_rank=get_tensor_model_parallel_rank(),
@@ -326,7 +331,7 @@ def __call__(
326331
loss_fn_wrapped = partial(
327332
wrap_loss_fn_with_input_preparation,
328333
loss_fn=self.loss_fn,
329-
prepare_fn=prepare_loss_input,
334+
prepare_fn=prepare_loss_input_wrapped,
330335
vocab_parallel_rank=get_tensor_model_parallel_rank(),
331336
vocab_parallel_group=get_tensor_model_parallel_group(),
332337
context_parallel_group=get_context_parallel_group(),
@@ -365,8 +370,9 @@ def _counteract_mcore_loss_averaging(*args, **kwargs):
365370

366371

367372
class LogprobsPostProcessor:
368-
def __init__(self, cfg: PolicyConfig):
373+
def __init__(self, cfg: PolicyConfig, sampling_params: TrainingSamplingParams):
369374
self.cfg = cfg
375+
self.sampling_params = sampling_params
370376

371377
def __call__(
372378
self,
@@ -406,6 +412,7 @@ def processor_fn_inner(output_tensor):
406412
inference_only=True,
407413
cp_group=get_context_parallel_group(),
408414
chunk_size=logprob_chunk_size,
415+
sampling_params=self.sampling_params,
409416
)
410417
else:
411418
token_logprobs = from_parallel_logits_to_logprobs(
@@ -416,6 +423,7 @@ def processor_fn_inner(output_tensor):
416423
tp_group=tp_grp,
417424
inference_only=True,
418425
chunk_size=logprob_chunk_size,
426+
sampling_params=self.sampling_params,
419427
)
420428

421429
# Prepend 0 logprob for first token to maintain same sequence length as input

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -781,8 +781,10 @@ def get_topk_logits(
781781
def use_reference_model(self) -> Generator[None, None, None]:
782782
"""Context manager that temporarily swaps the reference model and active model.
783783
784-
On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references
785-
On exit: Restores original references and re-flips cuda/cpu
784+
On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references.
785+
Also disables top-k/top-p filtering since the reference policy's distribution
786+
is different from the current policy, making filtered logprobs incompatible.
787+
On exit: Restores original references and re-flips cuda/cpu, restores sampling_params.
786788
"""
787789
with torch.no_grad():
788790
try:
@@ -796,10 +798,11 @@ def use_reference_model(self) -> Generator[None, None, None]:
796798
val = to_local_if_dtensor(v)
797799
val.copy_(self.reference_model_state_dict[k])
798800

799-
# - self.model is the original reference_model, now on CUDA
800-
# - curr_state_dict is the train model, now on CPU
801-
802-
# Save and adjust sampling_params for reference model
801+
# Temporarily disable top-k/top-p filtering for reference policy logprobs.
802+
# The reference policy has different weights, so its top-k/top-p set is
803+
# inherently different from the current policy. Using filtered logprobs
804+
# would cause -inf mismatches that cannot be resolved by masking.
805+
# Note: We keep temperature scaling since it was applied to prev_logprobs.
803806
saved_sampling_params = self.sampling_params
804807
if saved_sampling_params is not None:
805808
self.sampling_params = TrainingSamplingParams(
@@ -810,6 +813,8 @@ def use_reference_model(self) -> Generator[None, None, None]:
810813
else:
811814
self.sampling_params = None
812815

816+
# - self.model is the original reference_model, now on CUDA
817+
# - curr_state_dict is the train model, now on CPU
813818
yield
814819

815820
finally:

Comments (0)