Skip to content

Commit f7b3021

Browse files
committed
rebase & address feedback
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
1 parent 9d8b2b2 commit f7b3021

File tree

4 files changed

+100
-320
lines changed

4 files changed

+100
-320
lines changed

nemo_rl/models/megatron/data.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from contextlib import nullcontext
1516
from dataclasses import dataclass
1617
from typing import Any, Iterator, Optional, Tuple
1718

@@ -211,7 +212,7 @@ def process_microbatch(
211212
pad_packed_seq_to_multiple_of: int = 1,
212213
pad_full_seq_to: Optional[int] = None,
213214
pack_sequences: bool = False,
214-
straggler_timer: StragglerDetector = None,
215+
straggler_timer: Optional[StragglerDetector] = None,
215216
) -> tuple[
216217
torch.Tensor,
217218
torch.Tensor,
@@ -221,7 +222,8 @@ def process_microbatch(
221222
Optional[torch.Tensor],
222223
]:
223224
"""Process a microbatch for Megatron model forward pass."""
224-
with straggler_timer(bdata=True):
225+
ctx = straggler_timer(bdata=True) if straggler_timer is not None else nullcontext()
226+
with ctx:
225227
input_ids = data_dict["input_ids"]
226228
attention_mask = None
227229
position_ids = None
@@ -294,15 +296,15 @@ def process_global_batch(
294296
*,
295297
batch_idx: int,
296298
batch_size: int,
297-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
299+
) -> dict[str, Any]:
298300
"""Process a global batch and compute normalization factors.
299301
300302
Args:
301-
data: Full dataset
303+
data: Full dataset to extract a batch from
304+
loss_fn: Loss function (used to check loss type for token-level validation)
305+
dp_group: Data parallel process group for all-reduce
302306
batch_idx: Index of batch to extract
303307
batch_size: Size of batch to extract
304-
loss_fn: Loss function (used to check loss type)
305-
dp_mesh: Data parallel mesh
306308
307309
Returns:
308310
Dictionary containing:

nemo_rl/models/megatron/train.py

Lines changed: 71 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from contextlib import nullcontext
1516
from functools import partial
1617
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
1718

@@ -25,6 +26,7 @@
2526
get_tensor_model_parallel_rank,
2627
)
2728
from megatron.core.pipeline_parallel import get_forward_backward_func
29+
from megatron.core.utils import StragglerDetector
2830

2931
from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper
3032
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
@@ -35,6 +37,7 @@
3537
from_parallel_logits_to_logprobs_packed_sequences,
3638
)
3739
from nemo_rl.models.megatron.data import ProcessedMicrobatch
40+
from nemo_rl.models.policy import PolicyConfig
3841

3942
# Union type for any post-processing function (defined after classes below)
4043
PostProcessingFunction = Union[
@@ -47,24 +50,26 @@
4750
def model_forward(
4851
model: GPTModel,
4952
data_dict: BatchedDataDict[Any],
50-
cfg: Dict[str, Any],
53+
cfg: PolicyConfig,
5154
input_ids_cp_sharded: torch.Tensor,
5255
position_ids: torch.Tensor,
5356
attention_mask: torch.Tensor,
5457
packed_seq_params: Optional[PackedSeqParams] = None,
55-
defer_fp32_logits: Optional[bool] = None,
58+
defer_fp32_logits: Optional[bool] = False,
59+
straggler_timer: Optional[StragglerDetector] = None,
5660
) -> torch.Tensor:
5761
"""Perform a single forward pass through the model.
5862
5963
Args:
6064
model: The model to run forward pass on
61-
data_dict (BatchedDataDict): Dictionary containing batch data
62-
cfg (dict): Configuration dictionary
65+
data_dict: Dictionary containing batch data
66+
cfg: Policy configuration dictionary
6367
input_ids_cp_sharded: Context-parallel sharded input token IDs
6468
position_ids: Position IDs for tokens
6569
attention_mask: Attention mask for the sequence
6670
packed_seq_params: Parameters for packed sequences (optional)
67-
defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
71+
defer_fp32_logits: Whether to skip the conversion of logits to fp32
72+
straggler_timer: Straggler detector for profiling the forward pass
6873
6974
Returns:
7075
torch.Tensor: Output tensor from the model (logits)
@@ -81,31 +86,48 @@ def model_forward(
8186
additional_kwargs["packed_seq_params"] = packed_seq_params
8287
if defer_fp32_logits:
8388
additional_kwargs["fp32_output"] = False
84-
# with straggler_timer:
85-
output_tensor = model(
86-
input_ids=input_ids_cp_sharded,
87-
position_ids=position_ids,
88-
attention_mask=attention_mask,
89-
**additional_kwargs,
90-
**multimodal_data,
91-
)
9289

93-
# Apply temperature scaling to logits for training
94-
# This matches the dtensor worker's _apply_temperature_scaling in the train method
95-
if "generation" in cfg and cfg["generation"] is not None:
96-
output_tensor.div_(cfg["generation"]["temperature"])
90+
with straggler_timer() if straggler_timer is not None else nullcontext():
91+
output_tensor = model(
92+
input_ids=input_ids_cp_sharded,
93+
position_ids=position_ids,
94+
attention_mask=attention_mask,
95+
**additional_kwargs,
96+
**multimodal_data,
97+
)
98+
99+
apply_temperature_scaling(output_tensor, cfg)
97100

98101
return output_tensor
99102

100103

104+
def apply_temperature_scaling(
105+
logits: torch.Tensor,
106+
cfg: PolicyConfig,
107+
) -> torch.Tensor:
108+
"""Apply temperature scaling to logits.
109+
110+
Args:
111+
logits: Logits tensor to scale
112+
cfg: Policy configuration containing generation settings
113+
114+
Returns:
115+
torch.Tensor: Temperature-scaled logits
116+
"""
117+
if "generation" in cfg and cfg["generation"] is not None:
118+
logits.div_(cfg["generation"]["temperature"])
119+
return logits
120+
121+
101122
def forward_with_post_processing_fn(
102123
data_iterator: Iterator[ProcessedMicrobatch],
103124
model: GPTModel,
104-
cfg: Dict[str, Any],
125+
cfg: PolicyConfig,
105126
post_processing_fn: PostProcessingFunction,
106-
defer_fp32_logits: Optional[bool] = True,
127+
defer_fp32_logits: Optional[bool] = False,
107128
global_valid_seqs: Optional[torch.Tensor] = None,
108129
global_valid_toks: Optional[torch.Tensor] = None,
130+
straggler_timer: Optional[StragglerDetector] = None,
109131
) -> Tuple[torch.Tensor, Callable]:
110132
"""Perform forward pass with pre-processed microbatch and return output tensor and post-processing function.
111133
@@ -116,11 +138,12 @@ def forward_with_post_processing_fn(
116138
Args:
117139
data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
118140
model: The model to run forward pass on
119-
cfg (dict): Configuration dictionary
141+
cfg: Policy configuration dictionary
120142
post_processing_fn: Post-processing function to post-process the logits
121143
defer_fp32_logits: Whether to defer FP32 conversion of logits
122144
global_valid_seqs: Global valid sequence count for loss normalization
123145
global_valid_toks: Global valid token count for loss normalization
146+
straggler_timer: Straggler detector for profiling the forward pass
124147
125148
Returns:
126149
tuple: (output_tensor, post_processing_fn_wrapped)
@@ -140,14 +163,15 @@ def forward_with_post_processing_fn(
140163
cu_seqlens_padded = processed_mb.cu_seqlens_padded
141164

142165
output_tensor = model_forward(
143-
model,
144-
data_dict,
145-
cfg,
146-
input_ids_cp_sharded,
147-
position_ids,
148-
attention_mask,
149-
packed_seq_params,
150-
defer_fp32_logits,
166+
model=model,
167+
data_dict=data_dict,
168+
cfg=cfg,
169+
input_ids_cp_sharded=input_ids_cp_sharded,
170+
position_ids=position_ids,
171+
attention_mask=attention_mask,
172+
packed_seq_params=packed_seq_params,
173+
defer_fp32_logits=defer_fp32_logits,
174+
straggler_timer=straggler_timer,
151175
)
152176

153177
## calling post_processing_fn will return a function that takes the output tensor and returns a tuple of (loss, metrics)
@@ -180,17 +204,18 @@ def forward_with_post_processing_fn(
180204

181205
def megatron_forward_backward(
182206
model: GPTModel,
183-
cfg: Dict[str, Any],
207+
cfg: PolicyConfig,
184208
data_iterator: Iterator[ProcessedMicrobatch],
185209
num_microbatches: int,
186210
seq_length: int,
187211
mbs: int,
188212
post_processing_fn: PostProcessingFunction,
189213
forward_only: bool = False,
190-
defer_fp32_logits: Optional[bool] = None,
214+
defer_fp32_logits: Optional[bool] = False,
191215
global_valid_seqs: Optional[torch.Tensor] = None,
192216
global_valid_toks: Optional[torch.Tensor] = None,
193217
do_not_average_loss: bool = False,
218+
straggler_timer: Optional[StragglerDetector] = None,
194219
) -> Any:
195220
"""Execute forward and backward passes using Megatron's utilities.
196221
@@ -200,19 +225,21 @@ def megatron_forward_backward(
200225
201226
Args:
202227
model: The model to train
203-
cfg (dict): Configuration dictionary
228+
cfg: Policy configuration dictionary
204229
data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
205-
num_microbatches (int): Number of microbatches to process
206-
seq_length (int): Sequence length
207-
mbs (int): Micro batch size
230+
num_microbatches: Number of microbatches to process
231+
seq_length: Sequence length
232+
mbs: Micro batch size
208233
post_processing_fn: Post-processing function to post-process the logits
209-
forward_only (bool): If True, skip backward pass
210-
defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
234+
forward_only: If True, skip backward pass
235+
defer_fp32_logits: Whether to skip the conversion of logits to fp32
211236
global_valid_seqs: Global valid sequence count for loss normalization
212237
global_valid_toks: Global valid token count for loss normalization
238+
do_not_average_loss: If True, do not average loss across microbatches
239+
straggler_timer: Straggler detector for profiling the forward pass
213240
214241
Returns:
215-
BatchedDataDict: Results from the forward/backward execution
242+
Results from the forward/backward execution
216243
"""
217244
forward_step = partial(
218245
forward_with_post_processing_fn,
@@ -221,6 +248,7 @@ def megatron_forward_backward(
221248
defer_fp32_logits=defer_fp32_logits,
222249
global_valid_seqs=global_valid_seqs,
223250
global_valid_toks=global_valid_toks,
251+
straggler_timer=straggler_timer,
224252
)
225253
forward_backward_func = get_forward_backward_func()
226254
return forward_backward_func(
@@ -240,7 +268,7 @@ class LossPostProcessor:
240268
def __init__(
241269
self,
242270
loss_fn: LossFunction,
243-
cfg: Dict[str, Any],
271+
cfg: PolicyConfig,
244272
cp_normalize: bool = True,
245273
):
246274
self.loss_fn = loss_fn
@@ -261,11 +289,10 @@ def __call__(
261289
and context parallelism normalization.
262290
263291
Args:
264-
loss_fn: The base loss function to wrap
265-
cfg (dict): Configuration dictionary
266-
data_dict: Batched data dictionary
292+
data_dict: Batched data dictionary for the current microbatch
267293
packed_seq_params: Parameters for packed sequences (optional)
268-
cp_normalize (bool): Whether to normalize by context parallel size
294+
global_valid_seqs: Global valid sequence count for loss normalization
295+
global_valid_toks: Global valid token count for loss normalization
269296
270297
Returns:
271298
Callable: Function that takes output tensor and returns (loss, metrics) tuple
@@ -304,7 +331,7 @@ def _div_by_cp_size(*args, **kwargs):
304331

305332

306333
class LogprobsPostProcessor:
307-
def __init__(self, cfg: Dict[str, Any]):
334+
def __init__(self, cfg: PolicyConfig):
308335
self.cfg = cfg
309336

310337
def __call__(
@@ -369,7 +396,7 @@ def processor_fn_inner(output_tensor):
369396

370397

371398
class TopkLogitsPostProcessor:
372-
def __init__(self, cfg: Dict[str, Any], k: int):
399+
def __init__(self, cfg: PolicyConfig, k: int):
373400
self.cfg = cfg
374401
self.k = k
375402

0 commit comments

Comments
 (0)