File tree Expand file tree Collapse file tree 3 files changed +3
-2
lines changed
tests/unit/models/megatron Expand file tree Collapse file tree 3 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -58,6 +58,7 @@ class ProcessedMicrobatch:
5858 packed_seq_params: PackedSeqParams for sequence packing (None if not packing)
5959 cu_seqlens_padded: Padded cumulative sequence lengths (None if not packing)
6060 """
61+
6162 data_dict : BatchedDataDict [Any ]
6263 input_ids : torch .Tensor
6364 input_ids_cp_sharded : torch .Tensor
@@ -202,6 +203,7 @@ def get_microbatch_iterator(
202203 padded_seq_length ,
203204 )
204205
206+
205207def process_microbatch (
206208 data_dict : BatchedDataDict [Any ],
207209 seq_length_key : Optional [str ] = None ,
@@ -338,6 +340,7 @@ def process_global_batch(
338340 }
339341
340342
343+
341344def _pack_sequences_for_megatron (
342345 input_ids : torch .Tensor ,
343346 seq_lengths : torch .Tensor ,
Original file line number Diff line number Diff line change 8585 LogprobsPostProcessor ,
8686 TopkLogitsPostProcessor ,
8787)
88- from nemo_rl .models .megatron .community_import import import_model_from_hf_name
8988from nemo_rl .models .policy import PolicyConfig
9089from nemo_rl .models .policy .interfaces import (
9190 ColocatablePolicyInterface ,
Original file line number Diff line number Diff line change @@ -1709,4 +1709,3 @@ def test_get_pack_sequence_parameters_for_megatron(get_pack_sequence_parameters_
17091709 # Check that all workers succeeded
17101710 for i , result in enumerate (results ):
17111711 assert result ["success" ], f"Worker { i } failed: { result ['error' ]} "
1712-
You can’t perform that action at this time.
0 commit comments