From f7f4e802f7d1e834ca504183d84b1041ac4273b5 Mon Sep 17 00:00:00 2001
From: dev-tomek
Date: Mon, 13 Oct 2025 12:38:43 +0000
Subject: [PATCH] topk backward vectorized

---
 .../topk_details/_topk_backward.py           | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/python/triton_kernels/triton_kernels/topk_details/_topk_backward.py b/python/triton_kernels/triton_kernels/topk_details/_topk_backward.py
index eebe481771..3027d1b5f1 100644
--- a/python/triton_kernels/triton_kernels/topk_details/_topk_backward.py
+++ b/python/triton_kernels/triton_kernels/topk_details/_topk_backward.py
@@ -32,8 +32,8 @@ def _topk_backward(
     offs_xn = tl.arange(0, N_EXPTS_PAD)
     offs_yn = tl.arange(0, N_EXPTS_ACT)
     mask_xn = offs_xn < n_expts_tot
-    # recompute softmax
     y_indx = tl.load(Yi + offs_yn)
+    # recompute softmax
     x = tl.load(X + y_indx)
     x = x.to(tl.float32)
     y = tl.softmax(x)
@@ -41,11 +41,16 @@ def _topk_backward(
     dy = tl.load(DY + offs_yn)
     dy = dy.to(tl.float32)
     s = tl.sum(y * dy, 0)
-    # write-back input gradient
-    tl.store(DX + offs_xn, 0, mask=mask_xn)
-    tl.debug_barrier()
     if APPLY_SOFTMAX:
-        dx = y * (dy - s)
+        dx_topk = y * (dy - s)
     else:
-        dx = dy
-    tl.store(DX + y_indx, dx)
+        dx_topk = dy
+    # full gradient using vectorized operations
+    dx_full = tl.zeros([N_EXPTS_PAD], dtype=tl.float32)
+    offs_xn_expanded = offs_xn[:, None]
+    y_indx_expanded = y_indx[None, :]
+    match_mask = (offs_xn_expanded == y_indx_expanded)
+    dx_topk_expanded = dx_topk[None, :]
+    dx_full = tl.sum(tl.where(match_mask, dx_topk_expanded, 0.0), axis=1)
+    # write back
+    tl.store(DX + offs_xn, dx_full, mask=mask_xn)
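
For reference, below is a rough NumPy sketch of what the rewritten kernel computes for a single row when APPLY_SOFTMAX is set; the helper name topk_softmax_backward_row and its argument names are illustrative only and not part of the patch. It recomputes the softmax over the k selected logits, applies the softmax Jacobian to the incoming gradient, and scatters the result onto the full expert axis with the same broadcast-compare trick the patch uses in place of the old zero-fill, debug-barrier, and scattered-store sequence.

    # Illustrative sketch (not the kernel itself): single-row top-k softmax backward.
    import numpy as np

    def topk_softmax_backward_row(x_topk, y_indx, dy, n_expts_tot):
        # x_topk, y_indx, dy: 1-D arrays of length k (top-k logits, their indices,
        # and the incoming gradient); n_expts_tot: size of the full expert axis.
        # recompute softmax over the k selected logits
        e = np.exp(x_topk - x_topk.max())
        y = e / e.sum()
        # softmax Jacobian applied to dy: dx_topk = y * (dy - sum(y * dy))
        s = np.sum(y * dy)
        dx_topk = y * (dy - s)
        # vectorized scatter back to the full expert axis via a one-hot
        # broadcast comparison, mirroring match_mask / tl.where / tl.sum above
        offs_xn = np.arange(n_expts_tot)
        match = offs_xn[:, None] == y_indx[None, :]            # (n_expts_tot, k)
        dx_full = np.where(match, dx_topk[None, :], 0.0).sum(axis=1)
        return dx_full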