File tree Expand file tree Collapse file tree 1 file changed +11
-10
lines changed
tritonbench/operators/grouped_gemm Expand file tree Collapse file tree 1 file changed +11
-10
lines changed Original file line number Diff line number Diff line change @@ -108,17 +108,18 @@ def _inner():
108
108
# TODO: Does not work on hip
109
109
@register_benchmark (enabled = is_cuda ())
110
110
def preprocessed_pt2_triton_grouped_mm (self , group_A , group_B ):
111
- def _inner ():
112
- torch ._dynamo .reset ()
111
+ torch ._dynamo .reset ()
113
112
114
- with inductor_config .patch (
115
- max_autotune = True ,
116
- max_autotune_gemm_backends = "TRITON" ,
117
- autotune_fallback_to_aten = False ,
118
- ):
119
- A_packed , B_shared , offs = self .list_input_to_jagged (group_A , group_B )
120
- compiled = torch .compile (torch ._grouped_mm , dynamic = False )
121
- return compiled (A_packed , B_shared , offs = offs , bias = None )
113
+ with inductor_config .patch (
114
+ max_autotune = True ,
115
+ max_autotune_gemm_backends = "TRITON" ,
116
+ autotune_fallback_to_aten = False ,
117
+ ):
118
+ A_packed , B_shared , offs = self .list_input_to_jagged (group_A , group_B )
119
+ compiled = torch .compile (torch ._grouped_mm , dynamic = False )
120
+
121
+ def _inner ():
122
+ return compiled (A_packed , B_shared , offs = offs , bias = None )
122
123
123
124
return _inner
124
125
You can’t perform that action at this time.
0 commit comments