Commit 39ca6ce

Do the "preprocessing" right for PyTorch compiled grouped GEMM (#513)
Parent: 379a315

1 file changed: +11 -10

tritonbench/operators/grouped_gemm/operator.py

@@ -108,17 +108,18 @@ def _inner():
     # TODO: Does not work on hip
     @register_benchmark(enabled=is_cuda())
     def preprocessed_pt2_triton_grouped_mm(self, group_A, group_B):
-        def _inner():
-            torch._dynamo.reset()
+        torch._dynamo.reset()
 
-            with inductor_config.patch(
-                max_autotune=True,
-                max_autotune_gemm_backends="TRITON",
-                autotune_fallback_to_aten=False,
-            ):
-                A_packed, B_shared, offs = self.list_input_to_jagged(group_A, group_B)
-                compiled = torch.compile(torch._grouped_mm, dynamic=False)
-                return compiled(A_packed, B_shared, offs=offs, bias=None)
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+            autotune_fallback_to_aten=False,
+        ):
+            A_packed, B_shared, offs = self.list_input_to_jagged(group_A, group_B)
+            compiled = torch.compile(torch._grouped_mm, dynamic=False)
+
+        def _inner():
+            return compiled(A_packed, B_shared, offs=offs, bias=None)
 
         return _inner
 
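Why the change matters: the previous version ran torch._dynamo.reset(), the inductor_config.patch block, the jagged-input packing, and the torch.compile wrapping inside _inner, so every timed invocation repaid that setup cost. The fix hoists all of it into the enclosing function, and the returned closure does nothing but call the already-compiled grouped GEMM. Below is a minimal sketch of the same pattern, assuming plain torch.mm as a stand-in for torch._grouped_mm and a hypothetical make_timed_fn factory:

import torch

def make_timed_fn(a: torch.Tensor, b: torch.Tensor):
    # One-time setup, paid in the factory rather than in the timed closure.
    torch._dynamo.reset()  # clear any cached compilation state
    compiled = torch.compile(torch.mm, dynamic=False)
    compiled(a, b)  # optional warm-up: the first call triggers compilation

    def _inner():
        # Timed region: nothing but the compiled kernel launch.
        return compiled(a, b)

    return _inner

fn = make_timed_fn(torch.randn(256, 256), torch.randn(256, 256))
out = fn()  # subsequent calls measure execution, not compilation

This factory shape matches what the diff shows the harness expecting: the registered benchmark returns a zero-argument closure, and everything except the kernel call stays outside that closure.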
