Skip to content

Commit 2c15edb

Browse files
authored
gather_gemv: fix eager impl to support large input; fix compile impl (#455)
1 parent f70c7c3 commit 2c15edb

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

tritonbench/operators/gather_gemv/operator.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,25 @@ def test_0(self, p1, p2, p3) -> Callable:
4444

4545
@register_benchmark(baseline=True)
4646
def test_eager(self, w, idx, x):
47-
return lambda: w[idx].to(x.dtype) @ x
47+
s = x.size(0)
48+
49+
if s <= 8192:
50+
return lambda: w[idx].to(x.dtype) @ x
51+
52+
# For very large matrices (e.g. S=16384) the batched advanced indexing
53+
# path above launches a CUDA kernel with an invalid configuration.
54+
# Fall back to per-expert slicing which is slower but robust.
55+
def eager_impl():
56+
outputs = []
57+
for idx_val in idx.tolist():
58+
outputs.append(w[idx_val].to(x.dtype) @ x)
59+
return torch.stack(outputs, dim=0)
60+
61+
return eager_impl
4862

4963
@register_benchmark()
5064
def test_inductor(self, w, idx, x):
51-
@torch.compile
65+
@torch.compile(mode="max-autotune-no-cudagraphs")
5266
def gather_gemv(w, idx, x):
5367
return w[idx].to(x.dtype) @ x
5468

0 commit comments

Comments
 (0)