Skip to content

Commit f9dd427

Browse files
committed
[NPU] Optimize KLDiv backward kernel performance
1 parent 5c9cecc commit f9dd427

File tree

1 file changed

+47
-18
lines changed
  • src/liger_kernel/ops/backends/_ascend/ops

1 file changed

+47
-18
lines changed

src/liger_kernel/ops/backends/_ascend/ops/kl_div.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,14 @@ def _kldiv_kernel_forward(
8484
def _kldiv_kernel_backward(
8585
target_ptr,
8686
new_grads_ptr,
87+
grad_output_ptr,
8788
n_rows,
8889
n_cols,
8990
BLOCK_SIZE_M: tl.constexpr,
9091
BLOCK_SIZE_N: tl.constexpr,
9192
log_target: tl.constexpr = False,
93+
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
94+
has_grad_output: tl.constexpr = False,
9295
):
9396
pid = tl.program_id(0)
9497
num_progs = tl.num_programs(0)
@@ -97,6 +100,10 @@ def _kldiv_kernel_backward(
97100
grid_n = tl.cdiv(n_cols, BLOCK_SIZE_N)
98101
total_2d_blocks = grid_m * grid_n
99102

103+
# For reduced losses, grad_output is a scalar. Load it once per program.
104+
if not has_grad_output:
105+
grad_output_scalar = tl.load(grad_output_ptr)
106+
100107
# Persistent-program loop over logical 2D blocks.
101108
for block_idx in tl.range(pid, total_2d_blocks, num_progs):
102109
block_m = block_idx // grid_n
@@ -118,6 +125,17 @@ def _kldiv_kernel_backward(
118125
else:
119126
res = y_true * -1
120127

128+
if not has_grad_output:
129+
res = res * grad_output_scalar
130+
else:
131+
grad_output = tl.load(grad_output_ptr + offset, mask=mask, other=0.0)
132+
res = res * grad_output
133+
134+
if reduction == _REDUCTION_MODE_BATCHMEAN:
135+
res = res / n_rows
136+
elif reduction == _REDUCTION_MODE_MEAN:
137+
res = res / (n_rows * n_cols)
138+
121139
tl.store(new_grads_ptr + offset, res, mask=mask)
122140

123141

@@ -126,13 +144,24 @@ def _kldiv_kernel_backward(
126144
# -----------------------------------------------------------------------------
127145

128146

129-
def get_optimal_block_size(n_rows, dtype_size, BLOCK_SIZE_N: tl.constexpr, is_backward: bool = False):
147+
def get_optimal_block_size(
148+
n_rows,
149+
dtype_size,
150+
BLOCK_SIZE_N: tl.constexpr,
151+
is_backward: bool = False,
152+
needs_grad_output_tile: bool = False,
153+
):
130154
"""
131155
Calculate optimal BLOCK_SIZE_M using compute_default_tiling_strategy.
132156
"""
133157
# 1) Set memory multiplier
134-
# Backward is lighter than forward in this op, so use a smaller multiplier.
135-
multiplier = 2.5 if is_backward else 3.0
158+
# Backward is lighter than forward in this op, so we typically use a smaller multiplier.
159+
# If backward also needs to stream a full grad_output tile (i.e., grad_output is not a scalar),
160+
# its memory footprint becomes closer to forward, so we bump the multiplier.
161+
if is_backward:
162+
multiplier = 3.0 if needs_grad_output_tile else 2.5
163+
else:
164+
multiplier = 3.0
136165

137166
# For bf16/fp16 (dtype_size < 4), compile-time UB overflow was observed on some shapes.
138167
# Clamp to fp32 size for a conservative tiling estimate; this can be refined later.
@@ -199,35 +228,35 @@ def kldiv_backward_triton(target, grad_output, new_grads, log_target, reduction)
199228
reduction = _str_to_reduction_mode[reduction]
200229

201230
BLOCK_SIZE_N = triton.next_power_of_2(min(128, V))
202-
BLOCK_SIZE_M = get_optimal_block_size(BT, target.element_size(), BLOCK_SIZE_N, is_backward=True)
231+
# grad_output handling:
232+
# - numel() == 1: use scalar grad_output path in kernel.
233+
# - numel() != 1: stream per-element grad_output tile in kernel.
234+
has_grad_output_tile = grad_output.numel() != 1
235+
BLOCK_SIZE_M = get_optimal_block_size(
236+
BT,
237+
target.element_size(),
238+
BLOCK_SIZE_N,
239+
is_backward=True,
240+
needs_grad_output_tile=has_grad_output_tile,
241+
)
203242
num_cores = get_npu_core_count()
204243
total_blocks = triton.cdiv(BT, BLOCK_SIZE_M) * triton.cdiv(V, BLOCK_SIZE_N)
205244
grid = min(num_cores, total_blocks)
206245

207246
_kldiv_kernel_backward[(grid,)](
208247
target,
209248
new_grads,
249+
grad_output,
210250
BT,
211251
V,
212252
BLOCK_SIZE_M=BLOCK_SIZE_M,
213253
BLOCK_SIZE_N=BLOCK_SIZE_N,
214254
log_target=log_target,
255+
reduction=reduction,
256+
has_grad_output=has_grad_output_tile,
215257
)
216258

217-
# If kl div is the last layer, grad_output is 1.0. Skip the mul then.
218-
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
219-
derivative = new_grads
220-
else:
221-
derivative = new_grads * grad_output
222-
223-
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
224-
derivative = derivative / target.shape[0]
225-
elif reduction == _REDUCTION_MODE_SUM.value or reduction == _REDUCTION_MODE_NONE.value:
226-
pass
227-
elif reduction == _REDUCTION_MODE_MEAN.value:
228-
derivative = derivative / (target.shape[0] * target.shape[1])
229-
230-
return derivative
259+
return new_grads
231260

232261

233262
class LigerKLDivLossFunction(torch.autograd.Function):

0 commit comments

Comments
 (0)