|
23 | 23 | reduce_aux_losses_tracker_across_ranks, |
24 | 24 | ) |
25 | 25 |
|
| 26 | +from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper |
| 27 | +from nemo_rl.distributed.batched_data_dict import BatchedDataDict |
| 28 | + |
| 29 | + |
26 | 30 | def _round_up_to_multiple(value: int, multiple: int) -> int: |
27 | 31 | return ( |
28 | 32 | ((value + multiple - 1) // multiple * multiple) |
29 | 33 | if value % multiple != 0 |
30 | 34 | else value |
31 | 35 | ) |
32 | 36 |
|
| 37 | + |
def forward_step_arbitrary_loss(
    state: GlobalState,
    global_valid_seqs: torch.Tensor,
    global_valid_toks: torch.Tensor,
    data_iterator: Iterator[BatchedDataDict[Any]],
    model: GPTModel,
    loss_fn: LossFunction,
    pack_sequences: bool = False,
    defer_fp32_logits: Optional[bool] = None,
    cp_normalize: bool = True,
    policy_cfg: Optional[dict] = None,
):
    """Forward training step with support for packed sequences and context parallelism.

    Args:
        state (GlobalState): Global state for the run
        global_valid_seqs: Global count of valid sequences
        global_valid_toks: Global count of valid tokens
        data_iterator: Input data iterator yielding pre-processed microbatches
        model (GPTModel): The GPT Model
        loss_fn (LossFunction): Loss function to apply
        pack_sequences (bool): Whether to pack sequences for efficiency
        defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
        cp_normalize (bool): Whether to normalize the loss by the cp_size
        policy_cfg (Optional[dict]): Policy configuration containing generation parameters

    Returns:
        A ``(output_tensor, loss_fn_wrapped)`` pair: the model's forward output
        and a loss callable with all non-output arguments already bound, so it
        can later be invoked on the output tensor alone.

    Notes on packed sequences with context parallelism (CP):
        - When CP > 1, each sequence is padded to a multiple of (cp_size * 2)
        - The factor of 2 ensures load balancing for causal attention
        - cu_seqlens tracks actual sequence boundaries
        - cu_seqlens_padded tracks padded sequence boundaries for CP
        - Requires TransformerEngine >= 1.10 for CP support
    """
    # Timer used to attribute slow ("straggler") ranks to the forward pass below.
    straggler_timer = state.straggler_timer

    # Get the pre-processed microbatch from the iterator
    processed_mb = next(data_iterator)

    # Extract the processed components
    # NOTE(review): input_ids and cu_seqlens_padded are bound here but not used
    # further down in this function.
    data_dict = processed_mb.data_dict
    input_ids = processed_mb.input_ids
    input_ids_cp_sharded = processed_mb.input_ids_cp_sharded
    attention_mask = processed_mb.attention_mask
    position_ids = processed_mb.position_ids
    packed_seq_params = processed_mb.packed_seq_params
    cu_seqlens_padded = processed_mb.cu_seqlens_padded

    multimodal_data = data_dict.get_multimodal_dict(
        as_tensors=True, device=input_ids_cp_sharded.device
    )
    if len(multimodal_data) > 0:
        # NOTE(review): position_ids are dropped for multimodal inputs —
        # presumably the model derives positions internally in that case;
        # confirm against the model's multimodal forward signature.
        position_ids = None

    additional_kwargs = {}
    # Mamba models currently do not support packed_seq_params
    if packed_seq_params is not None:
        additional_kwargs["packed_seq_params"] = packed_seq_params

    if defer_fp32_logits:
        # Keep logits in the compute dtype; the fp32 conversion is deferred to
        # the loss computation instead of being done in the model forward.
        additional_kwargs["fp32_output"] = False

    with straggler_timer:
        output_tensor = model(
            input_ids=input_ids_cp_sharded,
            position_ids=position_ids,
            attention_mask=attention_mask,
            **additional_kwargs,
            **multimodal_data,
        )

    # Apply temperature scaling to logits for training
    # This matches the dtensor worker's _apply_temperature_scaling in the train method
    if (
        policy_cfg is not None
        and "generation" in policy_cfg
        and policy_cfg["generation"] is not None
    ):
        # In-place divide avoids allocating a second logits-sized tensor.
        # NOTE(review): raises KeyError if "temperature" is missing from the
        # generation config — assumed to always be set upstream; confirm.
        output_tensor.div_(policy_cfg["generation"]["temperature"])

    # Unpack the output tensor if we did packed sequences
    if pack_sequences and packed_seq_params is not None:
        # remove padding
        # Rebind loss_fn so per-sequence slices (via cu_seqlens) are handled
        # by the wrapper before the underlying loss is applied.
        loss_fn = SequencePackingLossWrapper(
            loss_fn=loss_fn,
            cu_seqlens_q=packed_seq_params.cu_seqlens_q,
            cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded,
        )

    loss_data = data_dict

    # Bind every argument except the model output, so the caller can invoke
    # the loss as loss_fn_wrapped(output_tensor).
    loss_fn_wrapped = partial(
        loss_fn,
        data=loss_data,
        global_valid_seqs=global_valid_seqs,
        global_valid_toks=global_valid_toks,
        vocab_parallel_rank=get_tensor_model_parallel_rank(),
        vocab_parallel_group=get_tensor_model_parallel_group(),
        context_parallel_group=get_context_parallel_group(),
    )

    if cp_normalize:
        # Each CP rank computes a loss over its own shard; dividing by cp_size
        # keeps the reduced loss at the same scale as the CP=1 case.
        cp_size = get_context_parallel_world_size()
        orig_loss_fn_wrapped = loss_fn_wrapped

        def _div_by_cp_size(*args, **kwargs):
            # Loss callables are expected to return a (loss, metrics) pair;
            # only the loss term is rescaled.
            loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
            return loss / cp_size, metrics

        loss_fn_wrapped = _div_by_cp_size

    return output_tensor, loss_fn_wrapped
| 149 | + |
| 150 | + |
33 | 151 | def broadcast_tensor( |
34 | 152 | tensor: torch.Tensor | None, src_rank: int, group: dist.ProcessGroup |
35 | 153 | ) -> torch.Tensor: |
|
0 commit comments