Skip to content

Commit a2a8a51

Browse files
committed
lint
Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent ea02f64 commit a2a8a51

File tree

5 files changed

+37
-19
lines changed

5 files changed

+37
-19
lines changed

nemo_rl/models/megatron/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_moe_layer_wise_logging_tracker,
3030
reduce_aux_losses_tracker_across_ranks,
3131
)
32+
3233
from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper
3334
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
3435

@@ -40,6 +41,7 @@ def _round_up_to_multiple(value: int, multiple: int) -> int:
4041
else value
4142
)
4243

44+
4345
def forward_step_arbitrary_loss(
4446
state: GlobalState,
4547
global_valid_seqs: torch.Tensor,

nemo_rl/models/megatron/data.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616
from typing import Any, Iterator, Optional, Tuple
1717

1818
import torch
19-
2019
from megatron.core.packed_seq_params import PackedSeqParams
2120
from megatron.core.parallel_state import (
2221
get_context_parallel_rank,
2322
get_context_parallel_world_size,
2423
)
2524
from megatron.training.utils import get_ltor_masks_and_position_ids
26-
from nemo_rl.models.megatron.common import _round_up_to_multiple
25+
2726
from nemo_rl.algorithms.interfaces import LossFunction, LossType
2827
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
2928
from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank
29+
from nemo_rl.models.megatron.common import _round_up_to_multiple
3030

3131

3232
@dataclass
@@ -45,6 +45,7 @@ class ProcessedMicrobatch:
4545
packed_seq_params: PackedSeqParams for sequence packing (None if not packing)
4646
cu_seqlens_padded: Padded cumulative sequence lengths (None if not packing)
4747
"""
48+
4849
data_dict: BatchedDataDict[Any]
4950
input_ids: torch.Tensor
5051
input_ids_cp_sharded: torch.Tensor
@@ -192,6 +193,7 @@ def get_microbatch_iterator(
192193
padded_seq_length,
193194
)
194195

196+
195197
def process_microbatch(
196198
data_dict: BatchedDataDict[Any],
197199
seq_length_key: Optional[str] = None,
@@ -200,7 +202,7 @@ def process_microbatch(
200202
pad_full_seq_to: Optional[int] = None,
201203
pack_sequences: bool = False,
202204
):
203-
#with straggler_timer(bdata=True):
205+
# with straggler_timer(bdata=True):
204206
input_ids = data_dict["input_ids"]
205207
attention_mask = None
206208
position_ids = None
@@ -217,9 +219,7 @@ def process_microbatch(
217219
assert seq_length_key is not None, (
218220
"seq_length_key must be provided for packed sequences"
219221
)
220-
assert seq_length_key in data_dict, (
221-
f"{seq_length_key} not found in data_dict"
222-
)
222+
assert seq_length_key in data_dict, f"{seq_length_key} not found in data_dict"
223223

224224
# Get sequence lengths and context parallel size
225225
seq_lengths = data_dict[seq_length_key]
@@ -240,7 +240,7 @@ def process_microbatch(
240240
cp_rank=get_context_parallel_rank(),
241241
cp_size=get_context_parallel_world_size(),
242242
)
243-
243+
244244
# For packed sequences, position_ids and attention_mask are typically None
245245
# The PackedSeqParams handles all necessary sequence information
246246
position_ids = None
@@ -265,6 +265,7 @@ def process_microbatch(
265265
cu_seqlens_padded,
266266
)
267267

268+
268269
def process_global_batch(
269270
data: BatchedDataDict[Any],
270271
batch_idx: int,
@@ -301,6 +302,7 @@ def process_global_batch(
301302
global_valid_toks,
302303
)
303304

305+
304306
def _pack_sequences_for_megatron(
305307
input_ids: torch.Tensor,
306308
seq_lengths: torch.Tensor,
@@ -605,6 +607,7 @@ def _unpack_sequences_from_megatron(
605607

606608
return unpacked_output
607609

610+
608611
def check_sequence_dim(data: BatchedDataDict[Any]):
609612
# dim 1 is always assumed to be the sequence dim, sanity check this here
610613
sequence_dim = 1
@@ -614,4 +617,4 @@ def check_sequence_dim(data: BatchedDataDict[Any]):
614617
assert v.shape[sequence_dim] == seq_dim_size, (
615618
f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}"
616619
)
617-
return sequence_dim, seq_dim_size
620+
return sequence_dim, seq_dim_size

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@
113113
forward_step_arbitrary_loss,
114114
get_moe_metrics,
115115
)
116+
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
116117
from nemo_rl.models.megatron.data import (
117118
get_microbatch_iterator,
118119
process_global_batch,
119120
)
120-
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
121121
from nemo_rl.models.policy import PolicyConfig
122122
from nemo_rl.models.policy.interfaces import (
123123
ColocatablePolicyInterface,

tests/unit/algorithms/test_sequence_packing_gradients.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
339339
"pipeline_model_parallel_size": 1,
340340
"context_parallel_size": cp_size,
341341
},
342-
343-
344342
},
345343
seq_length_key="input_lengths",
346344
pad_individual_seqs_to_multiple_of=pad_to_multiple,

tests/unit/models/megatron/test_megatron_data.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def test_processed_microbatch_fields(self):
7373
assert microbatch.packed_seq_params == mock_packed_seq_params
7474
assert torch.equal(microbatch.cu_seqlens_padded, mock_cu_seqlens_padded)
7575

76+
7677
class TestCheckSequenceDim:
7778
"""Tests for check_sequence_dim function."""
7879

@@ -154,7 +155,9 @@ def test_process_microbatch_no_packing(self, mock_get_masks):
154155

155156
# Create test data
156157
data_dict = MagicMock()
157-
input_ids = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0], [6, 7, 8, 9, 10, 11, 12, 0, 0, 0]])
158+
input_ids = torch.tensor(
159+
[[1, 2, 3, 4, 5, 0, 0, 0, 0, 0], [6, 7, 8, 9, 10, 11, 12, 0, 0, 0]]
160+
)
158161
data_dict.__getitem__ = MagicMock(return_value=input_ids)
159162

160163
(
@@ -178,7 +181,9 @@ def test_process_microbatch_no_packing(self, mock_get_masks):
178181
mock_get_masks.assert_called_once()
179182

180183
@patch("nemo_rl.models.megatron.data.get_context_parallel_rank", return_value=0)
181-
@patch("nemo_rl.models.megatron.data.get_context_parallel_world_size", return_value=1)
184+
@patch(
185+
"nemo_rl.models.megatron.data.get_context_parallel_world_size", return_value=1
186+
)
182187
@patch("nemo_rl.models.megatron.data._pack_sequences_for_megatron")
183188
def test_process_microbatch_with_packing(
184189
self, mock_pack, mock_cp_world, mock_cp_rank
@@ -226,7 +231,7 @@ def test_process_microbatch_with_packing(
226231
assert attention_mask is None
227232
assert position_ids is None
228233
assert cu_seqlens_padded is not None
229-
234+
230235
# Verify pack was called
231236
mock_pack.assert_called_once()
232237

@@ -323,6 +328,7 @@ def test_process_global_batch_requires_sample_mask(self):
323328

324329
assert "sample_mask must be present" in str(exc_info.value)
325330

331+
326332
class TestGetMicrobatchIterator:
327333
"""Tests for get_microbatch_iterator function."""
328334

@@ -383,8 +389,13 @@ def test_get_microbatch_iterator_sequence_packing(
383389
mock_make_iterator.return_value = mock_iterator
384390

385391
mock_data = MagicMock()
386-
mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter([])
387-
mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (10, 512)
392+
mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter(
393+
[]
394+
)
395+
mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (
396+
10,
397+
512,
398+
)
388399

389400
cfg = {
390401
"dynamic_batching": {"enabled": False},
@@ -473,8 +484,13 @@ def test_get_microbatch_iterator_auto_detects_seq_length_key(
473484
mock_make_iterator.return_value = mock_iterator
474485

475486
mock_data = MagicMock()
476-
mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter([])
477-
mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (5, 256)
487+
mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter(
488+
[]
489+
)
490+
mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (
491+
5,
492+
256,
493+
)
478494

479495
cfg = {
480496
"dynamic_batching": {"enabled": False},
@@ -1677,4 +1693,3 @@ def test_get_pack_sequence_parameters_for_megatron(get_pack_sequence_parameters_
16771693
# Check that all workers succeeded
16781694
for i, result in enumerate(results):
16791695
assert result["success"], f"Worker {i} failed: {result['error']}"
1680-

0 commit comments

Comments
 (0)