Commit 201758e (parent: 5f9e611)

Added more perf changes for triton kernels
3 files changed (+91, -39 lines)

torch/_inductor/codegen/triton.py (1 addition, 1 deletion)

@@ -3296,7 +3296,7 @@ def codegen_body(self):
             "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
         )
         self.body.writeline(
-            f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+            f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK, num_stages=2):"
         )
         with self.body.indent(offset=level + 1):
             self.iteration_ranges_codegen_header(tree, self.body)
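For context, a minimal sketch (not from this commit; assumes a recent Triton where tl.range accepts num_stages) of the kind of loop this codegen change emits. num_stages=2 asks Triton to software-pipeline the loop, overlapping global loads for the next block with compute on the current one:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def sum_kernel(in_ptr, out_ptr, rnumel, RBLOCK: tl.constexpr):
        acc = tl.zeros([RBLOCK], dtype=tl.float32)
        # Shaped like the generated `for r0_offset in tl.range(...)` loop:
        # num_stages=2 pipelines loads of block i+1 against compute on block i.
        for roffset in tl.range(0, rnumel, RBLOCK, num_stages=2):
            rindex = roffset + tl.arange(0, RBLOCK)
            mask = rindex < rnumel
            acc += tl.load(in_ptr + rindex, mask=mask, other=0.0)
        tl.store(out_ptr, tl.sum(acc, axis=0))

    x = torch.randn(1 << 16, device="cuda")
    out = torch.empty(1, device="cuda", dtype=torch.float32)
    sum_kernel[(1,)](x, out, x.numel(), RBLOCK=1024)
    assert torch.allclose(out[0], x.sum(), atol=1e-2)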

torch/_inductor/config.py (1 addition, 1 deletion)

@@ -1261,7 +1261,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32

     # Generate code containing the newer tl.make_block_ptr() API for loads/stores
     use_block_ptr = False
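Since spill_threshold lives under the triton class in torch/_inductor/config.py, the same value can also be overridden per run rather than by patching the source; a minimal usage sketch:

    import torch._inductor.config as inductor_config

    # Configs whose compiled kernels spill more registers than this are
    # skipped during autotune benchmarking (see triton_heuristics.py below).
    inductor_config.triton.spill_threshold = 32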

torch/_inductor/runtime/triton_heuristics.py (89 additions, 37 deletions)

@@ -765,7 +765,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -2198,7 +2198,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:


 def triton_config_tiled_reduction(
-    size_hints, x, y, r, num_stages=1, register_intensive=False
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
 ):
     """
     Construct a tile reduction triton config with some adjustment
@@ -2235,7 +2235,13 @@ def total_numel() -> int:
     )
     check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config


 def pointwise(
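For reference, waves_per_eu is an AMD-specific occupancy hint (target wavefronts resident per execution unit). The commit stores it in Config.kwargs, the same dict that carries the block sizes, so it reaches the ROCm compiler alongside the other meta-parameters. A standalone sketch of that pattern with plain triton.Config objects (hypothetical block sizes; assumes a ROCm Triton build that recognizes the kwarg):

    import triton

    # Candidate configs mirroring the ones added in this commit: on ROCm,
    # waves_per_eu rides in kwargs so each config can trade register
    # pressure for occupancy independently during autotuning.
    rocm_configs = [
        triton.Config({"XBLOCK": 1024, "waves_per_eu": 2}, num_warps=4, num_stages=1),
        triton.Config({"XBLOCK": 512, "waves_per_eu": 1}, num_warps=4, num_stages=1),
    ]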
@@ -2279,17 +2285,26 @@ def pointwise(
             triton_config_with_settings(
                 size_hints, bs // 2, num_elements_per_warp=64
             ),
-            # triton_config_with_settings(
-            #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
-            # ),
             triton_config_with_settings(
                 size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
             ),
+            triton_config_with_settings(
+                size_hints, 4096  # wrt: better than the max block for some kernels
+            ),
             *hinted_configs,
         ]
+        # Additional pointwise configs appended for ROCm builds
+        if torch.version.hip:
+            configs.append(triton_config_with_settings(
+                size_hints,
+                2048,
+                num_warps=8,
+                num_stages=2,
+                waves_per_eu=1,
+            ))  # 20% improvement
     if len(size_hints) == 2:
         if (
-            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
+            disable_pointwise_autotuning(inductor_meta)  # or tile_hint == TileHint.SQUARE
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
@@ -2298,9 +2313,13 @@ def pointwise(
         else:
             configs = [
                 triton_config_with_settings(size_hints, 32, 32),
+                triton_config_with_settings(size_hints, 64, 32),  # wrt: better for some kernels
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                 triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
+                triton_config_with_settings(size_hints, 128, 16),  # wrt: +10% for some kernels
+                triton_config_with_settings(size_hints, 128, 32),  # wrt: an additional +10%
+                triton_config_with_settings(size_hints, 32, 512),  # wrt: +30% for some kernels
                 triton_config_with_settings(size_hints, bs, 1),
                 triton_config_with_settings(size_hints, 1, bs),
                 *hinted_configs,
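The ROCm-only additions above hinge on torch.version.hip, which is a version string on ROCm builds of PyTorch and None otherwise, so they are no-ops on CUDA wheels. A quick illustration:

    import torch

    # torch.version.hip is e.g. "6.1..." on ROCm builds and None on CUDA
    # builds, so simple truthiness gates the extra configs cleanly.
    if torch.version.hip:
        print("ROCm build; ROCm-specific autotune configs will be appended")
    else:
        print("CUDA (or CPU) build; no waves_per_eu configs added")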
@@ -2340,6 +2359,12 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)

+    # Is max autotune enabled?
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     if (
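The new flag folds the two earlier checks into one: autotuning is considered on unless pointwise autotuning is disabled and neither max-autotune knob is set. A small self-contained check of that truth table (hypothetical helper, not the inductor function):

    def autotune_enabled(disabled: bool, max_autotune: bool, max_autotune_pointwise: bool) -> bool:
        # Mirrors: not disable_pointwise_autotuning(...) or (max_autotune or max_autotune_pointwise)
        return (not disabled) or (max_autotune or max_autotune_pointwise)

    assert autotune_enabled(False, False, False)     # default: sweep configs
    assert not autotune_enabled(True, False, False)  # disabled, no override: single config
    assert autotune_enabled(True, True, False)       # max_autotune re-enables the sweep
    assert autotune_enabled(True, False, True)       # so does max_autotune_pointwise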
@@ -2362,7 +2387,7 @@ def _reduction_configs(
         MAX_R0_BLOCK = 1024
         register_intensive = True

-    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
+    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None):
         # For 3D case with tiling scores, create an adapted version
         if "y" in size_hints:
             assert "tiling_scores" in inductor_meta
@@ -2374,6 +2399,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
             )
         else:
             # For other cases, use the original function
@@ -2384,6 +2410,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
             )

     contiguous_config = make_config(
@@ -2397,32 +2424,39 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
         min(rnumel, MAX_R0_BLOCK),
         register_intensive=register_intensive,
     )
-    # For 3d tiling, default to more autotuning initially
-    if "y" in size_hints:
-        pass
-    elif inductor_meta.get("max_autotune") or inductor_meta.get(
-        "max_autotune_pointwise"
-    ):
-        pass  # skip all these cases
-    elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
-    elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return [make_config(32, 128)]
-    return [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        make_config(64, 64),
-        make_config(8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        make_config(64, 4, num_warps=8),
-    ]
+
+    result_configs = []
+
+    if not (max_autotune_enabled or "y" in size_hints):
+        if reduction_hint == ReductionHint.INNER:
+            result_configs = [contiguous_config]
+        elif reduction_hint == ReductionHint.OUTER:
+            result_configs = [outer_config]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            result_configs = [tiny_config]
+        else:
+            result_configs = [make_config(32, 128)]
+    else:
+        result_configs = [
+            contiguous_config,
+            outer_config,
+            tiny_config,
+            make_config(64, 64),
+            make_config(8, 512),
+            # halve the XBLOCK/Rn_BLOCK compared to outer_config
+            # TODO: this may only be beneficial when each iteration of the reduction
+            # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+            make_config(64, 4, num_warps=8),
+        ]
+
+    # Add ROCm-specific configs when autotuning
+    if torch.version.hip:
+        result_configs.extend([
+            make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+        ])
+
+    return result_configs


 def match_target_block_product(
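To make the restructured selection concrete, here is a toy mirror (hypothetical helper, not the inductor code) of the new branch structure. Note that, as written in the diff, the torch.version.hip block appends the two ROCm configs on the single-config path as well, despite the "when autotuning" comment:

    def select_reduction_configs(max_autotune_enabled, has_y, reduction_hint, is_hip):
        """Toy mirror of the new branch structure in _reduction_configs."""
        if not (max_autotune_enabled or has_y):
            base = {
                "INNER": ["contiguous"],
                "OUTER": ["outer"],
                "OUTER_TINY": ["tiny"],
            }.get(reduction_hint, ["x32_r128"])
        else:
            base = ["contiguous", "outer", "tiny", "64x64", "8x512", "64x4_w8"]
        if is_hip:
            # Appended unconditionally, i.e. on both paths.
            base += ["1024x8_wpe2", "512x8_wpe1"]
        return base

    assert select_reduction_configs(False, False, "INNER", False) == ["contiguous"]
    assert len(select_reduction_configs(True, False, "INNER", True)) == 8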
@@ -2480,6 +2514,7 @@ def adapt_config_for_tiling(
     num_stages=1,
     register_intensive=False,
     persistent_reduction=False,
+    waves_per_eu=None,
 ) -> Config:
     """
     Create an adapted configuration based on tiling scores,
@@ -2498,6 +2533,7 @@ def adapt_config_for_tiling(
         block_sizes["r0_"],
         num_stages=num_stages,
         register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
     )


@@ -2608,15 +2644,15 @@ def _persistent_reduction_configs(
     if "y" not in size_hints:
         configs = [
             triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
-            for xblock in (1, 8, 32, 128)
+            for xblock in (1, 4, 8, 16, 32, 64, 128, 256, 512)
             if xblock == 1
-            or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
+            or (xblock <= xnumel and rnumel * xblock <= 4096)
         ]
     else:
         configs = []
         assert "tiling_scores" in inductor_meta
         x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
-        for target_block_size in (1, 8, 32, 64, 128):
+        for target_block_size in (1, 4, 8, 16, 32, 64, 128, 256, 512):
             if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
                 continue
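A worked example (hypothetical sizes) of the widened sweep with the hard rnumel * xblock <= 4096 cap that replaces the MAX_PERSISTENT_BLOCK_NUMEL bound: only small tiles survive for larger reductions.

    xnumel, rnumel = 64, 256

    candidates = [
        xblock
        for xblock in (1, 4, 8, 16, 32, 64, 128, 256, 512)
        if xblock == 1 or (xblock <= xnumel and rnumel * xblock <= 4096)
    ]
    assert candidates == [1, 4, 8, 16]  # xblock=32 would need 8192 elements per block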

@@ -2651,6 +2687,22 @@ def _persistent_reduction_configs(
         for conf in tiny_configs:
             if conf not in configs:
                 configs.append(conf)
+
+        # Expand configs to also try halved and doubled warp counts
+        expanded_configs = []
+        for conf in configs:
+            max_warps = 8 if torch.version.hip else 16
+            small_conf = copy.deepcopy(conf)
+            large_conf = copy.deepcopy(conf)
+            small_conf.num_warps = max(small_conf.num_warps // 2, 1)
+            large_conf.num_warps = min(large_conf.num_warps * 2, max_warps)
+            expanded_configs.append(conf)
+            expanded_configs.append(small_conf)
+            expanded_configs.append(large_conf)
+
+        configs = expanded_configs
+
     elif reduction_hint == ReductionHint.OUTER_TINY:
         configs = tiny_configs
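A standalone sketch of the expansion step above, using a stand-in config class (hypothetical, just to show that the search space triples, with warp variants clamped to [1, max_warps]):

    import copy
    from dataclasses import dataclass

    @dataclass
    class FakeConfig:
        num_warps: int

    def expand_warps(configs, is_hip=False):
        max_warps = 8 if is_hip else 16
        expanded = []
        for conf in configs:
            small, large = copy.deepcopy(conf), copy.deepcopy(conf)
            small.num_warps = max(conf.num_warps // 2, 1)
            large.num_warps = min(conf.num_warps * 2, max_warps)
            expanded += [conf, small, large]
        return expanded

    warps = [c.num_warps for c in expand_warps([FakeConfig(4), FakeConfig(16)])]
    assert warps == [4, 2, 8, 16, 8, 16]  # 16 * 2 clamps to max_warps=16 on CUDA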
