 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Any, NotRequired, Optional, TypedDict, TypeVar
+from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar
 
 import torch
 import torch.distributed
@@ -802,22 +802,27 @@ class SequencePackingLossWrapper:
     def __init__(
         self,
         loss_fn: LossFunction,
+        prepare_fn: Callable[..., Any],
         cu_seqlens_q: Tensor,
         cu_seqlens_q_padded: Optional[Tensor] = None,
+        vocab_parallel_rank: Optional[int] = None,
+        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         self.loss_fn = loss_fn
+        self.prepare_fn = prepare_fn
         self.cu_seqlens_q = cu_seqlens_q
         self.cu_seqlens_q_padded = cu_seqlens_q_padded
+        self.vocab_parallel_rank = vocab_parallel_rank
+        self.vocab_parallel_group = vocab_parallel_group
+        self.context_parallel_group = context_parallel_group
 
     def __call__(
         self,
         next_token_logits: Tensor,
         data: BatchedDataDict[Any],
         global_valid_seqs: Tensor | None,
         global_valid_toks: Tensor | None,
-        vocab_parallel_rank: Optional[int] = None,
-        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
-        context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
     ) -> tuple[Tensor, dict[str, Any]]:
         """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding."""
         unpadded_cu_seqlens = self.cu_seqlens_q
@@ -851,8 +856,8 @@ def __call__(
             # get next_token_logits
             cp_size = (
                 1
-                if context_parallel_group is None
-                else torch.distributed.get_world_size(context_parallel_group)
+                if self.context_parallel_group is None
+                else torch.distributed.get_world_size(self.context_parallel_group)
             )
             logit_start = seq_start // cp_size
             logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size
@@ -861,14 +866,14 @@ def __call__(
                 1, logit_start, logit_length
             )
 
+            # prepare data for loss function
+            loss_fn_args = self.prepare_fn(next_token_logits_slice, unpadded_seq_data)
+
             loss, metrics = self.loss_fn(
-                next_token_logits_slice,
+                *loss_fn_args,
                 unpadded_seq_data,
                 global_valid_seqs,
                 global_valid_toks,
-                vocab_parallel_rank=vocab_parallel_rank,
-                vocab_parallel_group=vocab_parallel_group,
-                context_parallel_group=context_parallel_group,
             )
             loss_accum += loss
             for k, v in metrics.items():
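
For context, a minimal usage sketch of the new interface (the loss function, the prepare_fn body, and the parallel-group handles below are illustrative assumptions, not part of this commit): the vocab/context-parallel state is now supplied once at construction, and prepare_fn maps each per-sequence logits slice and its data to the leading positional arguments of the wrapped loss.

# Hypothetical wiring; only SequencePackingLossWrapper's signature comes from this diff.
def make_loss_args(logits_slice, seq_data):
    # Build the tuple that the wrapper splats into the wrapped loss as *loss_fn_args.
    return (logits_slice,)

wrapper = SequencePackingLossWrapper(
    loss_fn=my_token_loss,                    # assumed LossFunction instance
    prepare_fn=make_loss_args,
    cu_seqlens_q=cu_seqlens_q,                # cumulative lengths of the packed sequences
    cu_seqlens_q_padded=cu_seqlens_q_padded,
    vocab_parallel_rank=tp_rank,              # assumed parallel-state handles
    vocab_parallel_group=tp_group,
    context_parallel_group=cp_group,
)
loss, metrics = wrapper(next_token_logits, data, global_valid_seqs, global_valid_toks)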