Skip to content

Commit 349eee9

Browse files
committed
fixed wrt1 configs, added wrt2&3 configs
1 parent 2746590 commit 349eee9

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2415,15 +2415,18 @@ def pointwise(
24152415
triton_config_with_settings(
24162416
size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
24172417
),
2418-
triton_config_with_settings(
2419-
size_hints, 4096 # wrt: better than the max_block for some kernel
2420-
),
24212418
*hinted_configs,
24222419
]
24232420
# Additional pointwise configs appended for ROCm builds
24242421
if torch.version.hip:
24252422
configs += [
2426-
triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1) # 20% improvement # ..where?
2423+
triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where?
2424+
triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel
2425+
triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
2426+
# -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
2427+
# triton_poi_fused_index_put_new_zeros_45
2428+
# triton_poi_fused_index_put_new_zeros_49
2429+
# triton_poi_fused_index_put_new_zeros_54
24272430
]
24282431
if len(size_hints) == 2:
24292432
if (
@@ -2437,7 +2440,7 @@ def pointwise(
24372440
configs = [
24382441
triton_config_with_settings(size_hints, 32, 32),
24392442
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
2440-
triton_config_with_settings(size_hints, 256, 16),
2443+
triton_config_with_settings(size_hints, 256, 16),
24412444
triton_config_with_settings(size_hints, 16, 256),
24422445
triton_config_with_settings(size_hints, bs, 1),
24432446
triton_config_with_settings(size_hints, 1, bs),
@@ -2448,8 +2451,8 @@ def pointwise(
24482451
]
24492452
# bypass triton_config_with_settings -> triton_config logic
24502453
if "x" in size_hints and "y" in size_hints:
2451-
cfg = {"XBLOCK": 32, "YBLOCK": 128}
24522454
configs += [
2455+
Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
24532456
Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
24542457
Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
24552458
]
@@ -2474,12 +2477,6 @@ def pointwise(
24742477

24752478
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
24762479

2477-
print()
2478-
print("Pointwise will use following configs")
2479-
for config in configs:
2480-
print(">", config)
2481-
print()
2482-
24832480
return cached_autotune(
24842481
size_hints,
24852482
configs,
@@ -2699,12 +2696,6 @@ def reduction(
26992696
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
27002697
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
27012698

2702-
print()
2703-
print("Reduction will use following configs")
2704-
for config in configs:
2705-
print(">", config)
2706-
print()
2707-
27082699
return cached_autotune(
27092700
size_hints,
27102701
configs=configs,

0 commit comments

Comments
 (0)