1616
1717
# Width of the vocab-dimension tile used by the chunked selective-logprob paths;
# bounds peak logit memory to (n_rows x this) per step.
_SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE = 2048
# Vocabularies at or below this size skip chunking and use a single exact
# log_softmax over the full logit matrix instead.
_SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD = 4096
1920
2021
2122def _next_power_of_two (x ):
@@ -115,8 +116,8 @@ def _selective_logprob_forward_torch(hidden, weight, targets, bias=None, tempera
115116 end = min (start + vocab_chunk_size , vocab_size )
116117 chunk_width = end - start
117118 weight_chunk = weight [start :end ].to (hidden .dtype )
118- torch .mm (hidden , weight_chunk .t (), out = mm_buf [:, :chunk_width ])
119119 logits_chunk = logits_buf [:, :chunk_width ]
120+ torch .mm (hidden , weight_chunk .t (), out = mm_buf [:, :chunk_width ])
120121 logits_chunk .copy_ (mm_buf [:, :chunk_width ])
121122 if bias is not None :
122123 logits_chunk .add_ (bias [start :end ].to (torch .float32 ))
@@ -138,6 +139,48 @@ def _selective_logprob_forward_torch(hidden, weight, targets, bias=None, tempera
138139 return target_logit - log_z , log_z
139140
140141
def _selective_logprob_forward_autograd(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=2048):
    """Differentiable selected-token log-probabilities without full logits.

    Uses only plain torch ops so autograd can trace the computation.  Small
    vocabularies (<= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD) take an exact
    single-shot log_softmax; larger ones use an online (streaming) log-sum-exp
    over vocab chunks so at most ``n_rows x vocab_chunk_size`` logits are alive
    at once.  All reductions are carried out in float32.

    Args:
        hidden: (n_rows, hidden_size) activations.
        weight: (vocab_size, hidden_size) lm-head weight.
        targets: (n_rows,) integer token ids to select.
        bias: optional (vocab_size,) lm-head bias.
        temperature: softmax temperature applied to the logits.
        vocab_chunk_size: vocab tile width for the streaming path.

    Returns:
        (n_rows,) float32 log-probabilities of the target tokens.
    """
    n_rows = hidden.shape[0]
    vocab_size = weight.shape[0]
    inv_t = 1.0 / temperature

    # Exact path: the whole logit matrix is small enough to materialize.
    if vocab_size <= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD:
        logits = (hidden @ weight.to(hidden.dtype).t()).float()
        if bias is not None:
            logits = logits + bias.to(torch.float32)
        logits = logits * inv_t
        return torch.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Streaming path: online log-sum-exp accumulated chunk by chunk.
    device = hidden.device
    running_max = torch.full((n_rows,), float("-inf"), device=device, dtype=torch.float32)
    running_sum = torch.zeros((n_rows,), device=device, dtype=torch.float32)
    target_logit = torch.zeros((n_rows,), device=device, dtype=torch.float32)
    rows = torch.arange(n_rows, device=device)

    for start in range(0, vocab_size, vocab_chunk_size):
        end = min(start + vocab_chunk_size, vocab_size)
        chunk = (hidden @ weight[start:end].to(hidden.dtype).t()).float()
        if bias is not None:
            chunk = chunk + bias[start:end].to(torch.float32)
        chunk = chunk * inv_t

        # Fold this chunk into the running max / rescaled exp-sum.
        new_max = torch.maximum(running_max, chunk.amax(dim=-1))
        rescale = torch.exp(running_max - new_max)
        running_sum = running_sum * rescale + torch.exp(chunk - new_max.unsqueeze(-1)).sum(dim=-1)
        running_max = new_max

        # Pick up the target logit for rows whose target lands in this chunk;
        # clamp keeps the gather index in-bounds for rows that miss (masked out).
        hit = (targets >= start) & (targets < end)
        local = torch.clamp(targets - start, 0, end - start - 1)
        target_logit = target_logit + chunk[rows, local] * hit

    log_z = running_max + torch.log(running_sum)
    return target_logit - log_z
182+
183+
141184def _selective_logprob_forward_triton (hidden , weight , targets , bias = None , temperature = 1.0 , vocab_chunk_size = 2048 ):
142185 n_rows , hidden_size = hidden .shape
143186 block_h = min (128 , _next_power_of_two (hidden_size ))
@@ -171,18 +214,6 @@ def _selective_logprob_forward_triton(hidden, weight, targets, bias=None, temper
171214
172215
def _selective_logprob_forward(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=2048):
    """Forward dispatch for the selective log-prob computation.

    Every call now routes through the torch reference implementation; the
    Triton fast path is no longer selected here.
    """
    return _selective_logprob_forward_torch(hidden, weight, targets, bias, temperature, vocab_chunk_size)
187218
188219
@@ -226,12 +257,12 @@ def _selective_logprob_backward(hidden, weight, targets, bias, log_z, grad_logpr
226257 n_rows , _ = hidden .shape
227258 vocab_size = weight .shape [0 ]
228259 has_bias = bias is not None
260+ hidden_fp32 = hidden .float ()
229261
230262 grad_hidden = torch .zeros (hidden .shape , device = hidden .device , dtype = torch .float32 )
231263 grad_weight = torch .zeros (weight .shape , device = weight .device , dtype = torch .float32 )
232264 grad_bias = torch .zeros ((vocab_size ,), device = weight .device , dtype = torch .float32 ) if has_bias else None
233265
234- mm_buf = torch .empty ((n_rows , vocab_chunk_size ), device = hidden .device , dtype = hidden .dtype )
235266 logits_buf = torch .empty ((n_rows , vocab_chunk_size ), device = hidden .device , dtype = torch .float32 )
236267
237268 grad_logprobs = grad_logprobs .to (torch .float32 )
@@ -240,11 +271,16 @@ def _selective_logprob_backward(hidden, weight, targets, bias, log_z, grad_logpr
240271 for start in range (0 , vocab_size , vocab_chunk_size ):
241272 end = min (start + vocab_chunk_size , vocab_size )
242273 chunk_width = end - start
243- weight_chunk = weight [start :end ]
244-
245- torch .mm (hidden , weight_chunk .t (), out = mm_buf [:, :chunk_width ])
274+ weight_chunk = weight [start :end ].float ()
246275 logits_chunk = logits_buf [:, :chunk_width ]
247- logits_chunk .copy_ (mm_buf [:, :chunk_width ])
276+ allow_tf32 = torch .backends .cuda .matmul .allow_tf32
277+ if hidden .is_cuda :
278+ torch .backends .cuda .matmul .allow_tf32 = False
279+ try :
280+ torch .mm (hidden_fp32 , weight_chunk .t (), out = logits_chunk )
281+ finally :
282+ if hidden .is_cuda :
283+ torch .backends .cuda .matmul .allow_tf32 = allow_tf32
248284 if has_bias :
249285 logits_chunk .add_ (bias [start :end ].to (torch .float32 ))
250286 logits_chunk .mul_ (inv_t )
@@ -257,8 +293,15 @@ def _selective_logprob_backward(hidden, weight, targets, bias, log_z, grad_logpr
257293 grad_logits [row_idx , local_idx ] += grad_logprobs * in_chunk
258294 grad_logits .mul_ (inv_t )
259295
260- grad_hidden .add_ (grad_logits @ weight_chunk .float ())
261- grad_weight [start :end ].add_ (grad_logits .t () @ hidden .float ())
296+ allow_tf32 = torch .backends .cuda .matmul .allow_tf32
297+ if hidden .is_cuda :
298+ torch .backends .cuda .matmul .allow_tf32 = False
299+ try :
300+ grad_hidden .add_ (grad_logits @ weight_chunk )
301+ grad_weight [start :end ].add_ (grad_logits .t () @ hidden_fp32 )
302+ finally :
303+ if hidden .is_cuda :
304+ torch .backends .cuda .matmul .allow_tf32 = allow_tf32
262305 if has_bias :
263306 grad_bias [start :end ].add_ (grad_logits .sum (dim = 0 ))
264307
@@ -385,7 +428,8 @@ def forward(
385428 delta = delta ,
386429 use_bias_correction_kl = use_bias_correction_kl ,
387430 )
388- compiled_compute_loss = torch .compile (compute_loss ) if compiled else compute_loss
431+ use_compiled_compute_loss = compiled and weight .shape [0 ] > _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD
432+ compiled_compute_loss = torch .compile (compute_loss ) if use_compiled_compute_loss else compute_loss
389433
390434 def fused_fwd_bwd (
391435 input_chunk ,
@@ -463,7 +507,10 @@ def accumulate_chunk(
463507 aggregated_metrics [i ].append (metric )
464508
465509 # Process input in chunks based on chunk_size
466- chunks = max (1 , _input .shape [0 ] // chunk_size )
510+ if weight .shape [0 ] <= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD :
511+ chunks = 1
512+ else :
513+ chunks = max (1 , _input .shape [0 ] // chunk_size )
467514 _input_chunks = torch .chunk (_input , chunks = chunks , dim = 0 )
468515 _selected_token_ids_chunks = torch .chunk (selected_token_ids , chunks = chunks , dim = 0 )
469516 _attention_mask_chunks = torch .chunk (attention_mask , chunks = chunks , dim = 0 )
@@ -598,13 +645,24 @@ def _compute_chunk_loss(
598645
599646 if use_ref_model and ref_per_token_logps_chunk is None :
600647 with torch .no_grad ():
601- ref_per_token_logps_chunk = LigerFusedLinearPPOBase .chunk_forward (
602- ref_input_chunk ,
603- ref_weight ,
604- selected_token_ids_chunk ,
605- bias = ref_bias ,
606- temperature = temperature ,
607- )
648+ if ref_weight .shape [0 ] <= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD :
649+ ref_logits = ref_input_chunk @ ref_weight .t ()
650+ if ref_bias is not None :
651+ ref_logits = ref_logits + ref_bias .float ()
652+ ref_logits = ref_logits .float ()
653+ if temperature != 1.0 :
654+ ref_logits = ref_logits / temperature
655+ ref_per_token_logps_chunk = torch .log_softmax (ref_logits , dim = - 1 ).gather (
656+ - 1 , selected_token_ids_chunk .unsqueeze (- 1 )
657+ ).squeeze (- 1 )
658+ else :
659+ ref_per_token_logps_chunk = LigerFusedLinearPPOBase .chunk_forward (
660+ ref_input_chunk ,
661+ ref_weight ,
662+ selected_token_ids_chunk ,
663+ bias = ref_bias ,
664+ temperature = temperature ,
665+ )
608666
609667 # Compute chunk loss and metrics using the provided loss function
610668 chunk_loss , chunk_metrics = ppo_loss_fn (
@@ -632,16 +690,20 @@ def _compute_chunk_loss(
632690 @staticmethod
633691 def chunk_forward (input_chunk , weight , selected_token_ids , bias = None , temperature = 1.0 ):
634692 """Compute selected-token log probabilities without materializing full vocab logits."""
693+ if weight .shape [0 ] <= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD :
694+ logits = input_chunk @ weight .t ()
695+ if bias is not None :
696+ logits = logits + bias
697+ logits = logits .float ()
698+ if temperature != 1.0 :
699+ logits = logits / temperature
700+ return torch .log_softmax (logits , dim = - 1 ).gather (- 1 , selected_token_ids .unsqueeze (- 1 )).squeeze (- 1 )
701+
635702 batch_size , seq_len , hidden_size = input_chunk .shape
636703 hidden = input_chunk .reshape (batch_size * seq_len , hidden_size ).contiguous ()
637704 targets = selected_token_ids .reshape (batch_size * seq_len ).contiguous ()
638- per_token_logps = _ChunkedSelectiveLogProbFunction .apply (
639- hidden ,
640- weight ,
641- targets ,
642- bias ,
643- temperature ,
644- _SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE ,
705+ per_token_logps = _selective_logprob_forward_autograd (
706+ hidden , weight , targets , bias , temperature , _SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE
645707 )
646708 return per_token_logps .reshape (batch_size , seq_len )
647709
0 commit comments