@@ -73,6 +73,7 @@ def test_processed_microbatch_fields(self):
7373 assert microbatch .packed_seq_params == mock_packed_seq_params
7474 assert torch .equal (microbatch .cu_seqlens_padded , mock_cu_seqlens_padded )
7575
76+
7677class TestCheckSequenceDim :
7778 """Tests for check_sequence_dim function."""
7879
@@ -154,7 +155,9 @@ def test_process_microbatch_no_packing(self, mock_get_masks):
154155
155156 # Create test data
156157 data_dict = MagicMock ()
157- input_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 , 0 , 0 , 0 , 0 , 0 ], [6 , 7 , 8 , 9 , 10 , 11 , 12 , 0 , 0 , 0 ]])
158+ input_ids = torch .tensor (
159+ [[1 , 2 , 3 , 4 , 5 , 0 , 0 , 0 , 0 , 0 ], [6 , 7 , 8 , 9 , 10 , 11 , 12 , 0 , 0 , 0 ]]
160+ )
158161 data_dict .__getitem__ = MagicMock (return_value = input_ids )
159162
160163 (
@@ -178,7 +181,9 @@ def test_process_microbatch_no_packing(self, mock_get_masks):
178181 mock_get_masks .assert_called_once ()
179182
180183 @patch ("nemo_rl.models.megatron.data.get_context_parallel_rank" , return_value = 0 )
181- @patch ("nemo_rl.models.megatron.data.get_context_parallel_world_size" , return_value = 1 )
184+ @patch (
185+ "nemo_rl.models.megatron.data.get_context_parallel_world_size" , return_value = 1
186+ )
182187 @patch ("nemo_rl.models.megatron.data._pack_sequences_for_megatron" )
183188 def test_process_microbatch_with_packing (
184189 self , mock_pack , mock_cp_world , mock_cp_rank
@@ -226,7 +231,7 @@ def test_process_microbatch_with_packing(
226231 assert attention_mask is None
227232 assert position_ids is None
228233 assert cu_seqlens_padded is not None
229-
234+
230235 # Verify pack was called
231236 mock_pack .assert_called_once ()
232237
@@ -323,6 +328,7 @@ def test_process_global_batch_requires_sample_mask(self):
323328
324329 assert "sample_mask must be present" in str (exc_info .value )
325330
331+
326332class TestGetMicrobatchIterator :
327333 """Tests for get_microbatch_iterator function."""
328334
@@ -383,8 +389,13 @@ def test_get_microbatch_iterator_sequence_packing(
383389 mock_make_iterator .return_value = mock_iterator
384390
385391 mock_data = MagicMock ()
386- mock_data .make_microbatch_iterator_for_packable_sequences .return_value = iter ([])
387- mock_data .get_microbatch_iterator_for_packable_sequences_len .return_value = (10 , 512 )
392+ mock_data .make_microbatch_iterator_for_packable_sequences .return_value = iter (
393+ []
394+ )
395+ mock_data .get_microbatch_iterator_for_packable_sequences_len .return_value = (
396+ 10 ,
397+ 512 ,
398+ )
388399
389400 cfg = {
390401 "dynamic_batching" : {"enabled" : False },
@@ -473,8 +484,13 @@ def test_get_microbatch_iterator_auto_detects_seq_length_key(
473484 mock_make_iterator .return_value = mock_iterator
474485
475486 mock_data = MagicMock ()
476- mock_data .make_microbatch_iterator_for_packable_sequences .return_value = iter ([])
477- mock_data .get_microbatch_iterator_for_packable_sequences_len .return_value = (5 , 256 )
487+ mock_data .make_microbatch_iterator_for_packable_sequences .return_value = iter (
488+ []
489+ )
490+ mock_data .get_microbatch_iterator_for_packable_sequences_len .return_value = (
491+ 5 ,
492+ 256 ,
493+ )
478494
479495 cfg = {
480496 "dynamic_batching" : {"enabled" : False },
@@ -1677,4 +1693,3 @@ def test_get_pack_sequence_parameters_for_megatron(get_pack_sequence_parameters_
16771693 # Check that all workers succeeded
16781694 for i , result in enumerate (results ):
16791695 assert result ["success" ], f"Worker { i } failed: { result ['error' ]} "
1680-
0 commit comments