Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,25 @@ def _topk_backward(
offs_xn = tl.arange(0, N_EXPTS_PAD)
offs_yn = tl.arange(0, N_EXPTS_ACT)
mask_xn = offs_xn < n_expts_tot
# recompute softmax
y_indx = tl.load(Yi + offs_yn)
# recompute softmax
x = tl.load(X + y_indx)
x = x.to(tl.float32)
y = tl.softmax(x)
# compute input-gradient
dy = tl.load(DY + offs_yn)
dy = dy.to(tl.float32)
s = tl.sum(y * dy, 0)
# write-back input gradient
tl.store(DX + offs_xn, 0, mask=mask_xn)
tl.debug_barrier()
if APPLY_SOFTMAX:
dx = y * (dy - s)
dx_topk = y * (dy - s)
else:
dx = dy
tl.store(DX + y_indx, dx)
dx_topk = dy
# full gradient using vectorized operations
dx_full = tl.zeros([N_EXPTS_PAD], dtype=tl.float32)
offs_xn_expanded = offs_xn[:, None]
y_indx_expanded = y_indx[None, :]
match_mask = (offs_xn_expanded == y_indx_expanded)
dx_topk_expanded = dx_topk[None, :]
dx_full = tl.sum(tl.where(match_mask, dx_topk_expanded, 0.0), axis=1)
Comment on lines +50 to +54
Copy link

Copilot AI Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The vectorized implementation creates large intermediate tensors through broadcasting. For large N_EXPTS_PAD and N_EXPTS_ACT, this creates an N_EXPTS_PAD × N_EXPTS_ACT matrix which may consume significant memory and be slower than the original scatter approach for sparse updates.

Suggested change
offs_xn_expanded = offs_xn[:, None]
y_indx_expanded = y_indx[None, :]
match_mask = (offs_xn_expanded == y_indx_expanded)
dx_topk_expanded = dx_topk[None, :]
dx_full = tl.sum(tl.where(match_mask, dx_topk_expanded, 0.0), axis=1)
# Scatter dx_topk into dx_full at positions y_indx
dx_full = tl.zeros([N_EXPTS_PAD], dtype=tl.float32)
for i in range(N_EXPTS_ACT):
idx = y_indx[i]
if idx < N_EXPTS_PAD:
dx_full = tl.store(dx_full, idx, dx_topk[i])

Copilot uses AI. Check for mistakes.
# write back
tl.store(DX + offs_xn, dx_full, mask=mask_xn)
Loading