Skip to content

Commit b19458a

Browse files
committed
triton sanity check for 2D POI
1 parent 6231607 commit b19458a

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,17 +2568,19 @@ def pointwise(
25682568
*hinted_configs,
25692569
]
25702570
if torch.version.hip:
2571-
configs += [ # add here
2572-
]
2573-
# bypass triton_config_with_settings -> triton_config logic
25742571
if "x" in size_hints and "y" in size_hints:
2575-
configs += [
2576-
Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
2577-
Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
2578-
Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
2579-
Config({"XBLOCK":64, "YBLOCK": 256}, num_warps=4), # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19
2580-
Config({"XBLOCK":512, "YBLOCK": 64}, num_warps=8), # wri0: 58us: triton_poi_fused_clone_53
2581-
]
2572+
"""add 2D tiling configs, but don't use triton_config_with_settings function
2573+
as it is buggy and may change the requested tiling unpredictably
2574+
"""
2575+
def addConfig__(xblock:int, yblock:int, num_warps:int):
2576+
# only add a tiling config if both the x and y size hints are at least as large as the tile
2577+
if size_hints["x"] >= xblock and size_hints["y"] >= yblock:
2578+
configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps))
2579+
addConfig__(512, 8, 8) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
2580+
addConfig__(32, 128, 4) # wrt2: 570us : triton_poi_fused_add_transpose_view_52
2581+
addConfig__(64, 32, 8) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
2582+
addConfig__(64, 256, 4) # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19
2583+
addConfig__(512, 64, 8) # wri0: 58us: triton_poi_fused_clone_53
25822584

25832585
if len(size_hints) == 3:
25842586
if disable_pointwise_autotuning(inductor_meta):

0 commit comments

Comments
 (0)