@dev-tomek
Contributor

No description provided.

Contributor

Copilot AI left a comment

Pull Request Overview

This PR optimizes the topk backward pass implementation by vectorizing gradient accumulation operations. The changes replace the previous approach of clearing and then selectively writing gradients with a more efficient vectorized method.

  • Replaces manual gradient clearing and selective writes with vectorized operations
  • Uses broadcasting and masking to accumulate gradients efficiently
  • Maintains the same computational logic while improving performance
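The broadcasting-and-masking idea in the bullets above can be modelled outside Triton. The following is a minimal plain-Python sketch, not the PR's kernel: the variable names mirror the diff hunk under review, while the helper names `masked_sum_scatter` and `loop_scatter` and the sample sizes are hypothetical. It assumes the top-k indices are unique, which is what makes the masked sum and the clear-then-write form agree.

```python
def masked_sum_scatter(offs_xn, y_indx, dx_topk):
    """Model of the broadcast-and-mask form: one output per entry of offs_xn."""
    dx_full = []
    for n in offs_xn:  # one row of the conceptual [N_EXPTS_PAD, N_EXPTS_ACT] mask
        # where(match_mask, dx_topk, 0.0), then a sum over the top-k axis
        row = [g if n == k else 0.0 for k, g in zip(y_indx, dx_topk)]
        dx_full.append(sum(row))
    return dx_full


def loop_scatter(n_expts_pad, y_indx, dx_topk):
    """Model of the clear-then-write form: zero the buffer, then fill chosen slots."""
    dx_full = [0.0] * n_expts_pad
    for k, g in zip(y_indx, dx_topk):
        if k < n_expts_pad:
            dx_full[k] = g
    return dx_full


# Hypothetical sizes: 8 padded experts, 3 active (top-k) experts.
offs_xn = list(range(8))
y_indx = [5, 1, 6]
dx_topk = [0.25, -1.5, 2.0]

vectorized = masked_sum_scatter(offs_xn, y_indx, dx_topk)
reference = loop_scatter(len(offs_xn), y_indx, dx_topk)
print(vectorized)  # [0.0, -1.5, 0.0, 0.0, 0.0, 0.25, 2.0, 0.0]
```

If an index appeared twice in `y_indx`, the masked sum would add both gradients while the write loop would keep only the last one, so the equivalence rests on the uniqueness assumption.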

Comment on lines +50 to +54
offs_xn_expanded = offs_xn[:, None]
y_indx_expanded = y_indx[None, :]
match_mask = (offs_xn_expanded == y_indx_expanded)
dx_topk_expanded = dx_topk[None, :]
dx_full = tl.sum(tl.where(match_mask, dx_topk_expanded, 0.0), axis=1)

Copilot AI Oct 15, 2025

The vectorized implementation creates large intermediate tensors through broadcasting. For large N_EXPTS_PAD and N_EXPTS_ACT, it builds an N_EXPTS_PAD × N_EXPTS_ACT matrix, which may consume significant memory and run slower than the original scatter approach for sparse updates.

Suggested change
- offs_xn_expanded = offs_xn[:, None]
- y_indx_expanded = y_indx[None, :]
- match_mask = (offs_xn_expanded == y_indx_expanded)
- dx_topk_expanded = dx_topk[None, :]
- dx_full = tl.sum(tl.where(match_mask, dx_topk_expanded, 0.0), axis=1)
+ # Scatter dx_topk into dx_full at positions y_indx
+ dx_full = tl.zeros([N_EXPTS_PAD], dtype=tl.float32)
+ for i in range(N_EXPTS_ACT):
+     idx = y_indx[i]
+     if idx < N_EXPTS_PAD:
+         dx_full = tl.store(dx_full, idx, dx_topk[i])
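
The size concern raised in the comment above can be made concrete with a little arithmetic. This is a rough sketch only: the (N_EXPTS_PAD, N_EXPTS_ACT) pairs are hypothetical and not taken from the PR, the helper names are illustrative, and element counts say nothing definite about real register or memory use after compilation. The suggested loop is read here as pseudocode for a per-index write, since `tl.store` writes through a pointer rather than returning an updated tensor.

```python
def mask_elements(n_expts_pad, n_expts_act):
    """Entries in the [N_EXPTS_PAD, N_EXPTS_ACT] match mask built by broadcasting."""
    return n_expts_pad * n_expts_act


def scatter_writes(n_expts_act):
    """Writes performed by a per-index scatter: one per active expert."""
    return n_expts_act


# Hypothetical (N_EXPTS_PAD, N_EXPTS_ACT) pairs, not taken from the PR.
shapes = [(8, 2), (128, 4), (512, 8)]
counts = [(mask_elements(p, a), scatter_writes(a)) for p, a in shapes]
for (p, a), (m, s) in zip(shapes, counts):
    print(f"pad={p:4d} act={a}: mask entries={m:5d}, scatter writes={s}")
```

The mask grows with the product of the two sizes while the scatter grows only with N_EXPTS_ACT, which is the trade-off the comment points at; whether it matters in practice would need a benchmark on the target hardware.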

Development

Successfully merging this pull request may close these issues.

Some test_routing.py::test_op tests cases fail on BMG