Skip to content

Commit e3e9a17

Browse files
authored
New WRT configs for autotuning (#2708)
Slightly reorganized how the hard-coded autotuning configs are added. Fixed the wrt1 configs. Added wrt2 and wrt3 configs.
1 parent b76ce7a commit e3e9a17

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2415,20 +2415,19 @@ def pointwise(
24152415
triton_config_with_settings(
24162416
size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
24172417
),
2418-
triton_config_with_settings(
2419-
size_hints, 4096 # wrt: better than the max_block for some kernel
2420-
),
24212418
*hinted_configs,
24222419
]
24232420
# Additional reduction configs appended for ROCm builds
24242421
if torch.version.hip:
2425-
configs.append(triton_config_with_settings(
2426-
size_hints,
2427-
2048,
2428-
num_warps=8,
2429-
num_stages=2,
2430-
waves_per_eu=1
2431-
)) # 20% improvement
2422+
configs += [
2423+
triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where?
2424+
triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel
2425+
triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
2426+
# -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
2427+
# triton_poi_fused_index_put_new_zeros_45
2428+
# triton_poi_fused_index_put_new_zeros_49
2429+
# triton_poi_fused_index_put_new_zeros_54
2430+
]
24322431
if len(size_hints) == 2:
24332432
if (
24342433
disable_pointwise_autotuning(inductor_meta) # or tile_hint == TileHint.SQUARE
@@ -2440,17 +2439,24 @@ def pointwise(
24402439
else:
24412440
configs = [
24422441
triton_config_with_settings(size_hints, 32, 32),
2443-
triton_config_with_settings(size_hints, 64, 32), # wrt: better for some kernels
24442442
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
2445-
triton_config_with_settings(size_hints, 256, 16),
2443+
triton_config_with_settings(size_hints, 256, 16),
24462444
triton_config_with_settings(size_hints, 16, 256),
2447-
triton_config_with_settings(size_hints, 128, 16), # wrt: +10% for some kernels
2448-
triton_config_with_settings(size_hints, 128, 32), # wrt: ..additional 10% more
2449-
triton_config_with_settings(size_hints, 32, 512), # wrt: +30% for some kernels
24502445
triton_config_with_settings(size_hints, bs, 1),
24512446
triton_config_with_settings(size_hints, 1, bs),
24522447
*hinted_configs,
24532448
]
2449+
if torch.version.hip:
2450+
configs += [ # add here
2451+
]
2452+
# bypass triton_config_with_settings -> triton_config logic
2453+
if "x" in size_hints and "y" in size_hints:
2454+
configs += [
2455+
Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
2456+
Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
2457+
Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
2458+
]
2459+
24542460
if len(size_hints) == 3:
24552461
if disable_pointwise_autotuning(inductor_meta):
24562462
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
@@ -2583,9 +2589,14 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, wa
25832589
if torch.version.hip:
25842590
result_configs.extend([
25852591
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
2586-
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
2592+
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
2593+
make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
2594+
make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
2595+
make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
2596+
make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
2597+
make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
25872598
])
2588-
2599+
25892600
return result_configs
25902601

25912602

0 commit comments

Comments
 (0)