1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- from functools import partial
15- from typing import Any , Iterator , Optional
14+
15+ from typing import Any , Optional
1616
1717import torch
1818import torch .distributed as dist
19- from megatron .bridge .training .state import GlobalState
20- from megatron .core .models .gpt import GPTModel
21- from megatron .core .parallel_state import (
22- get_context_parallel_group ,
23- get_context_parallel_world_size ,
24- get_tensor_model_parallel_group ,
25- get_tensor_model_parallel_rank ,
26- )
2719from megatron .core .transformer .moe .moe_utils import (
2820 clear_aux_losses_tracker ,
2921 get_moe_layer_wise_logging_tracker ,
3022 reduce_aux_losses_tracker_across_ranks ,
3123)
3224
33- from nemo_rl .algorithms .loss_functions import LossFunction , SequencePackingLossWrapper
34- from nemo_rl .distributed .batched_data_dict import BatchedDataDict
35-
3625
3726def _round_up_to_multiple (value : int , multiple : int ) -> int :
3827 return (
@@ -42,119 +31,6 @@ def _round_up_to_multiple(value: int, multiple: int) -> int:
4231 )
4332
4433
def forward_step_arbitrary_loss(
    state: GlobalState,
    global_valid_seqs: torch.Tensor,
    global_valid_toks: torch.Tensor,
    data_iterator: Iterator[BatchedDataDict[Any]],
    model: GPTModel,
    loss_fn: LossFunction,
    pack_sequences: bool = False,
    defer_fp32_logits: Optional[bool] = None,
    cp_normalize: bool = True,
    policy_cfg: Optional[dict] = None,
):
    """Forward training step with support for packed sequences and context parallelism.

    Args:
        state (GlobalState): Global state for the run.
        global_valid_seqs: Global count of valid sequences.
        global_valid_toks: Global count of valid tokens.
        data_iterator: Iterator yielding pre-processed microbatches.
        model (GPTModel): The GPT model to run the forward pass on.
        loss_fn (LossFunction): Loss function to apply.
        pack_sequences (bool): Whether sequences were packed for efficiency.
        defer_fp32_logits (Optional[bool]): Whether to skip the conversion of
            logits to fp32 (the loss function upcasts later instead).
        cp_normalize (bool): Whether to normalize the loss by the
            context-parallel world size.
        policy_cfg (Optional[dict]): Policy configuration containing generation
            parameters (``policy_cfg["generation"]["temperature"]`` is applied
            to the logits when present).

    Returns:
        Tuple of ``(output_tensor, loss_fn_wrapped)`` as expected by
        Megatron's forward-backward pipeline functions.

    Notes on packed sequences with context parallelism (CP):
        - When CP > 1, each sequence is padded to a multiple of (cp_size * 2)
        - The factor of 2 ensures load balancing for causal attention
        - cu_seqlens tracks actual sequence boundaries
        - cu_seqlens_padded tracks padded sequence boundaries for CP
        - Requires TransformerEngine >= 1.10 for CP support
    """
    straggler_timer = state.straggler_timer

    # Get the pre-processed microbatch from the iterator.
    processed_mb = next(data_iterator)

    # Extract the processed components. The raw `input_ids` and
    # `cu_seqlens_padded` fields were previously read here but never used:
    # the model consumes the CP-sharded ids, and the padded cu_seqlens travel
    # inside `packed_seq_params`.
    data_dict = processed_mb.data_dict
    input_ids_cp_sharded = processed_mb.input_ids_cp_sharded
    attention_mask = processed_mb.attention_mask
    position_ids = processed_mb.position_ids
    packed_seq_params = processed_mb.packed_seq_params

    multimodal_data = data_dict.get_multimodal_dict(
        as_tensors=True, device=input_ids_cp_sharded.device
    )
    # When multimodal inputs are present, position ids are derived by the
    # model itself, so drop the precomputed ones.
    if len(multimodal_data) > 0:
        position_ids = None

    additional_kwargs = {}
    # Mamba models currently do not support packed_seq_params.
    if packed_seq_params is not None:
        additional_kwargs["packed_seq_params"] = packed_seq_params

    # Keep logits in their compute dtype instead of upcasting to fp32 here.
    if defer_fp32_logits:
        additional_kwargs["fp32_output"] = False

    with straggler_timer:
        output_tensor = model(
            input_ids=input_ids_cp_sharded,
            position_ids=position_ids,
            attention_mask=attention_mask,
            **additional_kwargs,
            **multimodal_data,
        )

    # Apply temperature scaling to logits for training.
    # This matches the dtensor worker's _apply_temperature_scaling in the
    # train method.
    if policy_cfg is not None and policy_cfg.get("generation") is not None:
        output_tensor.div_(policy_cfg["generation"]["temperature"])

    # If sequences were packed, wrap the loss function so it strips the
    # per-sequence CP padding before computing the loss.
    if pack_sequences and packed_seq_params is not None:
        loss_fn = SequencePackingLossWrapper(
            loss_fn=loss_fn,
            cu_seqlens_q=packed_seq_params.cu_seqlens_q,
            cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded,
        )

    loss_fn_wrapped = partial(
        loss_fn,
        data=data_dict,
        global_valid_seqs=global_valid_seqs,
        global_valid_toks=global_valid_toks,
        vocab_parallel_rank=get_tensor_model_parallel_rank(),
        vocab_parallel_group=get_tensor_model_parallel_group(),
        context_parallel_group=get_context_parallel_group(),
    )

    if cp_normalize:
        # Each CP rank computes a partial loss; dividing by cp_size keeps the
        # reduced loss consistent with the unsharded value.
        cp_size = get_context_parallel_world_size()
        orig_loss_fn_wrapped = loss_fn_wrapped

        def _div_by_cp_size(*args, **kwargs):
            loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
            return loss / cp_size, metrics

        loss_fn_wrapped = _div_by_cp_size

    return output_tensor, loss_fn_wrapped
156-
157-
15834def broadcast_tensor (
15935 tensor : torch .Tensor | None , src_rank : int , group : dist .ProcessGroup
16036) -> torch .Tensor :
0 commit comments