Skip to content

Commit 451c2b4

Browse files
AmdSampsa authored and naromero77amd committed
New WRT configs for autotuning (#2708)
Slightly reorganized how the hard-coded autotuning configs are added. Fixed the wrt1 configs. Added wrt2 and wrt3 configs. (cherry picked from commit e3e9a17) (cherry picked from commit 6534df0)
1 parent 768ab84 commit 451c2b4

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,6 +1945,15 @@ def pointwise(
19451945
size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1
19461946
)
19471947
) # 20% improvement
1948+
configs += [
1949+
triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where?
1950+
triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel
1951+
triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
1952+
# -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
1953+
# triton_poi_fused_index_put_new_zeros_45
1954+
# triton_poi_fused_index_put_new_zeros_49
1955+
# triton_poi_fused_index_put_new_zeros_54
1956+
]
19481957
if len(size_hints) == 2:
19491958
# Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
19501959
# ROCm has observed improvement by diverging here
@@ -1963,7 +1972,7 @@ def pointwise(
19631972
size_hints, 64, 32
19641973
), # better for some kernels
19651974
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
1966-
triton_config_with_settings(size_hints, 256, 16),
1975+
triton_config_with_settings(size_hints, 256, 16),
19671976
triton_config_with_settings(size_hints, 16, 256),
19681977
triton_config_with_settings(
19691978
size_hints, 128, 16
@@ -1976,6 +1985,17 @@ def pointwise(
19761985
triton_config_with_settings(size_hints, 1, bs),
19771986
*hinted_configs,
19781987
]
1988+
if torch.version.hip:
1989+
configs += [ # add here
1990+
]
1991+
# bypass triton_config_with_settings -> triton_config logic
1992+
if "x" in size_hints and "y" in size_hints:
1993+
configs += [
1994+
Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
1995+
Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
1996+
Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
1997+
]
1998+
19791999
if len(size_hints) == 3:
19802000
if disable_pointwise_autotuning(inductor_meta):
19812001
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
@@ -2188,8 +2208,13 @@ def outer_config_opt():
21882208
[
21892209
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
21902210
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
2191-
]
2192-
)
2211+
make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
2212+
make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
2213+
make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
2214+
make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
2215+
make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
2216+
]
2217+
)
21932218

21942219
return result_configs
21952220

0 commit comments

Comments
 (0)