Skip to content

Commit f510015

Browse files
committed
lint
Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent 9c7cab8 commit f510015

File tree

6 files changed

+130
-92
lines changed

6 files changed

+130
-92
lines changed

nemo_rl/models/megatron/common.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,11 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional, Any
15+
from typing import Any, Optional
1616

1717
import torch
1818
import torch.distributed as dist
19-
2019
from megatron.core.transformer.moe.moe_utils import (
2120
clear_aux_losses_tracker,
2221
get_moe_layer_wise_logging_tracker,
@@ -280,4 +279,4 @@ def get_moe_metrics(
280279
metrics[f"moe/{name}_layer_{i}"] = float(loss)
281280

282281
clear_aux_losses_tracker()
283-
return metrics
282+
return metrics

nemo_rl/models/megatron/pipeline_parallel.py

Lines changed: 21 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -24,15 +24,19 @@
2424
is_pipeline_last_stage,
2525
)
2626

27+
2728
def broadcast_obj_from_pp_rank(obj: Any) -> Any:
2829
"""Broadcast an object across pipeline parallel ranks.
2930
This utility function handles broadcasting an object from the rank that owns it
3031
to all other pipeline parallel ranks. If only one rank has the object (non-None),
3132
it will be broadcast to all other ranks.
33+
3234
Args:
3335
obj: The object to broadcast. Can be None on ranks that don't own it.
36+
3437
Returns:
3538
The object on all ranks (either the original or the broadcast copy).
39+
3640
Raises:
3741
ValueError: If the object doesn't exist on any pipeline parallel rank.
3842
"""
@@ -72,21 +76,22 @@ def broadcast_obj_from_pp_rank(obj: Any) -> Any:
7276

7377
return obj_list[0]
7478

79+
7580
def broadcast_loss_metrics_from_last_stage(loss_metrics: Optional[list] = None) -> list:
7681
"""Broadcast loss metrics from the last pipeline stage to all stages.
77-
82+
7883
This utility handles the common pattern where loss computation happens on the last
7984
pipeline stage and needs to be broadcast to all other stages.
80-
85+
8186
Args:
8287
loss_metrics: List of loss metrics if on last stage, None otherwise
83-
88+
8489
Returns:
8590
List of loss metrics on all ranks
8691
"""
8792
pp_group = get_pipeline_model_parallel_group()
8893
last_rank = get_pipeline_model_parallel_last_rank()
89-
94+
9095
if is_pipeline_last_stage(ignore_virtual=True):
9196
metrics_to_broadcast = [loss_metrics]
9297
torch.distributed.broadcast_object_list(
@@ -106,36 +111,38 @@ def broadcast_loss_metrics_from_last_stage(loss_metrics: Optional[list] = None)
106111

107112

108113
def broadcast_tensors_from_last_stage(
109-
tensors: dict[str, Optional[torch.Tensor]],
114+
tensors: dict[str, Optional[torch.Tensor]],
110115
) -> dict[str, torch.Tensor]:
111116
"""Broadcast multiple tensors from the last pipeline stage to all stages.
112-
117+
113118
Args:
114119
tensors: Dictionary mapping tensor names to tensors (None on non-last stages)
115120
pp_group: Pipeline parallel group (auto-detected if None)
116-
121+
117122
Returns:
118123
Dictionary of broadcasted tensors on all ranks
119124
"""
120125
pp_group = get_pipeline_model_parallel_group()
121-
126+
122127
from nemo_rl.models.megatron.common import broadcast_tensor
123-
128+
124129
last_rank = get_pipeline_model_parallel_last_rank()
125130
current_rank = torch.distributed.get_rank()
126-
131+
127132
broadcasted_tensors = {}
128-
133+
129134
if is_pipeline_last_stage(ignore_virtual=True):
130135
# Broadcast tensors from last stage
131136
for name, tensor in tensors.items():
132137
if tensor is not None:
133-
broadcasted_tensors[name] = broadcast_tensor(tensor, current_rank, pp_group)
138+
broadcasted_tensors[name] = broadcast_tensor(
139+
tensor, current_rank, pp_group
140+
)
134141
else:
135142
broadcasted_tensors[name] = None
136143
else:
137144
# Receive tensors on other stages
138145
for name in tensors.keys():
139146
broadcasted_tensors[name] = broadcast_tensor(None, last_rank, pp_group)
140-
141-
return broadcasted_tensors
147+
148+
return broadcasted_tensors

nemo_rl/models/megatron/train.py

Lines changed: 33 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
1717

1818
import torch
19-
2019
from megatron.core.models.gpt import GPTModel
2120
from megatron.core.packed_seq_params import PackedSeqParams
2221
from megatron.core.parallel_state import (
@@ -37,7 +36,6 @@
3736
)
3837
from nemo_rl.models.megatron.data import ProcessedMicrobatch
3938

40-
4139
# Union type for any post-processing function (defined after classes below)
4240
PostProcessingFunction = Union[
4341
"LossPostProcessor",
@@ -56,9 +54,8 @@ def model_forward(
5654
packed_seq_params: Optional[PackedSeqParams] = None,
5755
defer_fp32_logits: Optional[bool] = None,
5856
) -> torch.Tensor:
59-
"""
60-
Perform a single forward pass through the model.
61-
57+
"""Perform a single forward pass through the model.
58+
6259
Args:
6360
model: The model to run forward pass on
6461
data_dict (BatchedDataDict): Dictionary containing batch data
@@ -68,7 +65,7 @@ def model_forward(
6865
attention_mask: Attention mask for the sequence
6966
packed_seq_params: Parameters for packed sequences (optional)
7067
defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
71-
68+
7269
Returns:
7370
torch.Tensor: Output tensor from the model (logits)
7471
"""
@@ -84,7 +81,7 @@ def model_forward(
8481
additional_kwargs["packed_seq_params"] = packed_seq_params
8582
if defer_fp32_logits:
8683
additional_kwargs["fp32_output"] = False
87-
#with straggler_timer:
84+
# with straggler_timer:
8885
output_tensor = model(
8986
input_ids=input_ids_cp_sharded,
9087
position_ids=position_ids,
@@ -95,14 +92,12 @@ def model_forward(
9592

9693
# Apply temperature scaling to logits for training
9794
# This matches the dtensor worker's _apply_temperature_scaling in the train method
98-
if (
99-
"generation" in cfg
100-
and cfg["generation"] is not None
101-
):
95+
if "generation" in cfg and cfg["generation"] is not None:
10296
output_tensor.div_(cfg["generation"]["temperature"])
103-
97+
10498
return output_tensor
10599

100+
106101
def forward_with_post_processing_fn(
107102
data_iterator: Iterator[ProcessedMicrobatch],
108103
model: GPTModel,
@@ -112,22 +107,21 @@ def forward_with_post_processing_fn(
112107
global_valid_seqs: Optional[torch.Tensor] = None,
113108
global_valid_toks: Optional[torch.Tensor] = None,
114109
) -> Tuple[torch.Tensor, Callable]:
115-
"""
116-
Perform forward pass with pre-processed microbatch and return output tensor and post-processing function.
117-
110+
"""Perform forward pass with pre-processed microbatch and return output tensor and post-processing function.
111+
118112
This function takes a pre-processed microbatch (with sequence packing already handled),
119113
runs the forward step through the model, and prepares a post-processing function for
120114
post-processing the outputs.
121-
115+
122116
Args:
123117
data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
124-
model: The model to run forward pass on
118+
model: The model to run forward pass on
125119
cfg (dict): Configuration dictionary
126120
post_processing_fn: Post-processing function to post-process the logits
127121
defer_fp32_logits: Whether to defer FP32 conversion of logits
128122
global_valid_seqs: Global valid sequence count for loss normalization
129123
global_valid_toks: Global valid token count for loss normalization
130-
124+
131125
Returns:
132126
tuple: (output_tensor, post_processing_fn_wrapped)
133127
- output_tensor: Raw model outputs (logits)
@@ -177,10 +171,13 @@ def forward_with_post_processing_fn(
177171
cu_seqlens_padded=cu_seqlens_padded,
178172
)
179173
else:
180-
raise TypeError(f"Unknown post-processing function type: {type(post_processing_fn)}")
174+
raise TypeError(
175+
f"Unknown post-processing function type: {type(post_processing_fn)}"
176+
)
181177

182178
return output_tensor, post_processing_fn_wrapped
183179

180+
184181
def megatron_forward_backward(
185182
model: GPTModel,
186183
cfg: Dict[str, Any],
@@ -195,13 +192,12 @@ def megatron_forward_backward(
195192
global_valid_toks: Optional[torch.Tensor] = None,
196193
do_not_average_loss: bool = False,
197194
) -> Any:
198-
"""
199-
Execute forward and backward passes using Megatron's utilities.
200-
195+
"""Execute forward and backward passes using Megatron's utilities.
196+
201197
This is the main training loop function that coordinates forward and backward
202198
passes across multiple microbatches using Megatron's pipeline parallel
203199
execution framework.
204-
200+
205201
Args:
206202
model: The model to train
207203
cfg (dict): Configuration dictionary
@@ -214,7 +210,7 @@ def megatron_forward_backward(
214210
defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
215211
global_valid_seqs: Global valid sequence count for loss normalization
216212
global_valid_toks: Global valid token count for loss normalization
217-
213+
218214
Returns:
219215
BatchedDataDict: Results from the forward/backward execution
220216
"""
@@ -239,8 +235,8 @@ def megatron_forward_backward(
239235
do_not_average_loss=do_not_average_loss,
240236
)
241237

242-
class LossPostProcessor:
243238

239+
class LossPostProcessor:
244240
def __init__(
245241
self,
246242
loss_fn: LossFunction,
@@ -250,16 +246,16 @@ def __init__(
250246
self.loss_fn = loss_fn
251247
self.cfg = cfg
252248
self.cp_normalize = cp_normalize
253-
254-
def __call__(self,
249+
250+
def __call__(
251+
self,
255252
data_dict: BatchedDataDict[Any],
256253
packed_seq_params: Optional[PackedSeqParams] = None,
257254
global_valid_seqs: Optional[torch.Tensor] = None,
258255
global_valid_toks: Optional[torch.Tensor] = None,
259256
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, Any]]]:
260-
"""
261-
Create a loss post-processing function for training.
262-
257+
"""Create a loss post-processing function for training.
258+
263259
This function wraps a loss function with the necessary context and parameters
264260
to compute loss and metrics from model outputs. It handles sequence packing
265261
and context parallelism normalization.
@@ -274,7 +270,6 @@ def __call__(self,
274270
Returns:
275271
Callable: Function that takes output tensor and returns (loss, metrics) tuple
276272
"""
277-
278273
loss_fn = self.loss_fn
279274
pack_sequences = self.cfg["sequence_packing"]["enabled"]
280275
if pack_sequences and packed_seq_params is not None:
@@ -307,8 +302,8 @@ def _div_by_cp_size(*args, **kwargs):
307302

308303
return loss_fn_wrapped
309304

310-
class LogprobsPostProcessor:
311305

306+
class LogprobsPostProcessor:
312307
def __init__(self, cfg: Dict[str, Any]):
313308
self.cfg = cfg
314309

@@ -318,8 +313,7 @@ def __call__(
318313
input_ids: torch.Tensor,
319314
cu_seqlens_padded: torch.Tensor,
320315
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
321-
"""
322-
Create a post-processing function that computes token log probabilities.
316+
"""Create a post-processing function that computes token log probabilities.
323317
324318
This function returns a processor that takes model logits and converts them
325319
to token-level log probabilities, handling both packed and unpacked sequences.
@@ -370,11 +364,11 @@ def processor_fn_inner(output_tensor):
370364
return torch.tensor(0.0, device=token_logprobs.device), {
371365
"logprobs": token_logprobs
372366
}
367+
373368
return processor_fn_inner
374369

375370

376371
class TopkLogitsPostProcessor:
377-
378372
def __init__(self, cfg: Dict[str, Any], k: int):
379373
self.cfg = cfg
380374
self.k = k
@@ -384,8 +378,7 @@ def __call__(
384378
data_dict: BatchedDataDict[Any],
385379
cu_seqlens_padded: torch.Tensor,
386380
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
387-
"""
388-
Create a post-processing function that computes top-k logits and indices.
381+
"""Create a post-processing function that computes top-k logits and indices.
389382
390383
This function returns a processor that extracts the top-k highest logits
391384
and their corresponding vocabulary indices from model outputs. It handles
@@ -396,10 +389,9 @@ def __call__(
396389
cu_seqlens_padded: Cumulative sequence lengths for packed sequences
397390
398391
Returns:
399-
Callable: Function that takes output tensor and returns
392+
Callable: Function that takes output tensor and returns
400393
(dummy_loss, {"topk_logits": values, "topk_indices": indices})
401394
"""
402-
403395
pack = self.cfg["sequence_packing"]["enabled"]
404396
cp_size = self.cfg["megatron_cfg"]["context_parallel_size"]
405397
unpacked_seqlen = data_dict["input_ids"].shape[1]
@@ -521,4 +513,5 @@ def processor_fn_inner(output_tensor):
521513
"topk_logits": topk_vals_full,
522514
"topk_indices": topk_idx_full,
523515
}
524-
return processor_fn_inner
516+
517+
return processor_fn_inner

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -95,8 +95,6 @@
9595
)
9696
from nemo_rl.models.generation.vllm.config import VllmConfig
9797
from nemo_rl.models.megatron.common import (
98-
broadcast_tensor,
99-
forward_step_arbitrary_loss,
10098
get_moe_metrics,
10199
)
102100
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
@@ -1018,7 +1016,9 @@ def train(
10181016

10191017
# Broadcast loss metrics from last stage to all stages
10201018
## TODO: check with PP > 1
1021-
gb_loss_metrics = broadcast_loss_metrics_from_last_stage(gb_loss_metrics)
1019+
gb_loss_metrics = broadcast_loss_metrics_from_last_stage(
1020+
gb_loss_metrics
1021+
)
10221022
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
10231023
mb_losses = [x["loss"] for x in gb_loss_metrics]
10241024

tests/unit/algorithms/test_sequence_packing_gradients.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -41,14 +41,14 @@ def __init__(self, cp_size):
4141

4242
def test_sequence_packing_gradients(self):
4343
from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank
44-
from nemo_rl.models.megatron.train import (
45-
forward_with_post_processing_fn,
46-
LossPostProcessor,
47-
)
4844
from nemo_rl.models.megatron.data import (
4945
_pack_sequences_for_megatron,
5046
make_processed_microbatch_iterator,
5147
)
48+
from nemo_rl.models.megatron.train import (
49+
LossPostProcessor,
50+
forward_with_post_processing_fn,
51+
)
5252

5353
# Initialize process group
5454
torch.distributed.init_process_group(backend="nccl")

0 commit comments

Comments
 (0)