Skip to content

Commit 6534df0

Browse files
AmdSampsajataylo
authored andcommitted
New WRT configs for autotuning (ROCm#2708)
Reorganized slightly the adding of hard-coded autotuning configs. Fixed wrt1 configs. Added wrt2 & 3 configs. (cherry picked from commit e3e9a17)
1 parent 4433700 commit 6534df0

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2526,6 +2526,15 @@ def pointwise(
25262526
size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1
25272527
)
25282528
) # 20% improvement
2529+
configs += [
2530+
triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where?
2531+
triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel
2532+
triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
2533+
# -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
2534+
# triton_poi_fused_index_put_new_zeros_45
2535+
# triton_poi_fused_index_put_new_zeros_49
2536+
# triton_poi_fused_index_put_new_zeros_54
2537+
]
25292538
if len(size_hints) == 2:
25302539
# Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
25312540
# ROCm has observed improvement by diverging here
@@ -2544,7 +2553,7 @@ def pointwise(
25442553
size_hints, 64, 32
25452554
), # better for some kernels
25462555
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
2547-
triton_config_with_settings(size_hints, 256, 16),
2556+
triton_config_with_settings(size_hints, 256, 16),
25482557
triton_config_with_settings(size_hints, 16, 256),
25492558
triton_config_with_settings(
25502559
size_hints, 128, 16
@@ -2557,6 +2566,17 @@ def pointwise(
25572566
triton_config_with_settings(size_hints, 1, bs),
25582567
*hinted_configs,
25592568
]
2569+
if torch.version.hip:
2570+
configs += [ # add here
2571+
]
2572+
# bypass triton_config_with_settings -> triton_config logic
2573+
if "x" in size_hints and "y" in size_hints:
2574+
configs += [
2575+
Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
2576+
Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
2577+
Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
2578+
]
2579+
25602580
if len(size_hints) == 3:
25612581
if disable_pointwise_autotuning(inductor_meta):
25622582
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
@@ -2763,8 +2783,13 @@ def outer_config_opt():
27632783
[
27642784
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
27652785
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
2766-
]
2767-
)
2786+
make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
2787+
make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
2788+
make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
2789+
make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
2790+
make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
2791+
]
2792+
)
27682793

27692794
return result_configs
27702795

0 commit comments

Comments
 (0)