@@ -84,11 +84,14 @@ def _kldiv_kernel_forward(
8484def _kldiv_kernel_backward (
8585 target_ptr ,
8686 new_grads_ptr ,
87+ grad_output_ptr ,
8788 n_rows ,
8889 n_cols ,
8990 BLOCK_SIZE_M : tl .constexpr ,
9091 BLOCK_SIZE_N : tl .constexpr ,
9192 log_target : tl .constexpr = False ,
93+ reduction : tl .constexpr = _REDUCTION_MODE_BATCHMEAN ,
94+ has_grad_output : tl .constexpr = False ,
9295):
9396 pid = tl .program_id (0 )
9497 num_progs = tl .num_programs (0 )
@@ -97,6 +100,10 @@ def _kldiv_kernel_backward(
97100 grid_n = tl .cdiv (n_cols , BLOCK_SIZE_N )
98101 total_2d_blocks = grid_m * grid_n
99102
103+ # For reduced losses, grad_output is a scalar. Load it once per program.
104+ if not has_grad_output :
105+ grad_output_scalar = tl .load (grad_output_ptr )
106+
100107 # Persistent-program loop over logical 2D blocks.
101108 for block_idx in tl .range (pid , total_2d_blocks , num_progs ):
102109 block_m = block_idx // grid_n
@@ -118,6 +125,17 @@ def _kldiv_kernel_backward(
118125 else :
119126 res = y_true * - 1
120127
128+ if not has_grad_output :
129+ res = res * grad_output_scalar
130+ else :
131+ grad_output = tl .load (grad_output_ptr + offset , mask = mask , other = 0.0 )
132+ res = res * grad_output
133+
134+ if reduction == _REDUCTION_MODE_BATCHMEAN :
135+ res = res / n_rows
136+ elif reduction == _REDUCTION_MODE_MEAN :
137+ res = res / (n_rows * n_cols )
138+
121139 tl .store (new_grads_ptr + offset , res , mask = mask )
122140
123141
@@ -126,13 +144,24 @@ def _kldiv_kernel_backward(
126144# -----------------------------------------------------------------------------
127145
128146
129- def get_optimal_block_size (n_rows , dtype_size , BLOCK_SIZE_N : tl .constexpr , is_backward : bool = False ):
147+ def get_optimal_block_size (
148+ n_rows ,
149+ dtype_size ,
150+ BLOCK_SIZE_N : tl .constexpr ,
151+ is_backward : bool = False ,
152+ needs_grad_output_tile : bool = False ,
153+ ):
130154 """
131155 Calculate optimal BLOCK_SIZE_M using compute_default_tiling_strategy.
132156 """
133157 # 1) Set memory multiplier
134- # Backward is lighter than forward in this op, so use a smaller multiplier.
135- multiplier = 2.5 if is_backward else 3.0
158+ # Backward is lighter than forward in this op, so we typically use a smaller multiplier.
159+ # If backward also needs to stream a full grad_output tile (i.e., grad_output is not a scalar),
160+ # its memory footprint becomes closer to forward, so we bump the multiplier.
161+ if is_backward :
162+ multiplier = 3.0 if needs_grad_output_tile else 2.5
163+ else :
164+ multiplier = 3.0
136165
137166 # For bf16/fp16 (dtype_size < 4), compile-time UB overflow was observed on some shapes.
138167 # Clamp to fp32 size for a conservative tiling estimate; this can be refined later.
@@ -199,35 +228,35 @@ def kldiv_backward_triton(target, grad_output, new_grads, log_target, reduction)
199228 reduction = _str_to_reduction_mode [reduction ]
200229
201230 BLOCK_SIZE_N = triton .next_power_of_2 (min (128 , V ))
202- BLOCK_SIZE_M = get_optimal_block_size (BT , target .element_size (), BLOCK_SIZE_N , is_backward = True )
231+ # grad_output handling:
232+ # - numel() == 1: use scalar grad_output path in kernel.
233+ # - numel() != 1: stream per-element grad_output tile in kernel.
234+ has_grad_output_tile = grad_output .numel () != 1
235+ BLOCK_SIZE_M = get_optimal_block_size (
236+ BT ,
237+ target .element_size (),
238+ BLOCK_SIZE_N ,
239+ is_backward = True ,
240+ needs_grad_output_tile = has_grad_output_tile ,
241+ )
203242 num_cores = get_npu_core_count ()
204243 total_blocks = triton .cdiv (BT , BLOCK_SIZE_M ) * triton .cdiv (V , BLOCK_SIZE_N )
205244 grid = min (num_cores , total_blocks )
206245
207246 _kldiv_kernel_backward [(grid ,)](
208247 target ,
209248 new_grads ,
249+ grad_output ,
210250 BT ,
211251 V ,
212252 BLOCK_SIZE_M = BLOCK_SIZE_M ,
213253 BLOCK_SIZE_N = BLOCK_SIZE_N ,
214254 log_target = log_target ,
255+ reduction = reduction ,
256+ has_grad_output = has_grad_output_tile ,
215257 )
216258
217- # If kl div is the last layer, grad_output is 1.0. Skip the mul then.
218- if torch .equal (grad_output , torch .tensor (1.0 , device = grad_output .device )):
219- derivative = new_grads
220- else :
221- derivative = new_grads * grad_output
222-
223- if reduction == _REDUCTION_MODE_BATCHMEAN .value :
224- derivative = derivative / target .shape [0 ]
225- elif reduction == _REDUCTION_MODE_SUM .value or reduction == _REDUCTION_MODE_NONE .value :
226- pass
227- elif reduction == _REDUCTION_MODE_MEAN .value :
228- derivative = derivative / (target .shape [0 ] * target .shape [1 ])
229-
230- return derivative
259+ return new_grads
231260
232261
233262class LigerKLDivLossFunction (torch .autograd .Function ):
0 commit comments