Skip to content

Commit 47078b5

Browse files
committed
Added triton perf improvement changes
1 parent b65a09a commit 47078b5

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,11 +1083,11 @@ def relu(x):
10831083

10841084
@staticmethod
10851085
def minimum(a, b):
1086-
return f"triton_helpers.minimum({a}, {b})"
1086+
return f"tl.minimum({a}, {b})"
10871087

10881088
@staticmethod
10891089
def maximum(a, b):
1090-
return f"triton_helpers.maximum({a}, {b})"
1090+
return f"tl.maximum({a}, {b})"
10911091

10921092
@staticmethod
10931093
def where(a, b, c):
@@ -1273,7 +1273,7 @@ def load_seed(name, offset):
12731273
@staticmethod
12741274
@maybe_upcast_float32()
12751275
def rsqrt(x):
1276-
return f"libdevice.rsqrt({x})"
1276+
return f"tl.rsqrt({x})"
12771277

12781278
@staticmethod
12791279
@maybe_upcast_float32()

torch/_inductor/runtime/hints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
1414
# NOTE: if these assertions fail, submit a PR to increase the maximums
1515
TRITON_MAX_BLOCK = {
16-
"X": 4096,
16+
"X": 8192,
1717
"Y": 1024,
1818
"Z": 1024,
1919
"R0_": 4096 * 16, # * 16 is multi-kernel only

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,9 @@ def triton_config(
20712071
num_stages=1,
20722072
num_elements_per_warp=256,
20732073
min_elem_per_thread=0,
2074+
num_warps=None,
2075+
matrix_instr=None,
2076+
waves_per_eu=None
20742077
) -> Config:
20752078
"""
20762079
Construct a pointwise triton config with some adjustment heuristics
@@ -2127,9 +2130,11 @@ def triton_config(
21272130
):
21282131
z *= 2
21292132

2130-
num_warps = _num_warps(
2131-
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
2132-
)
2133+
# Calculate num_warps only if it was not explicitly passed to the config
2134+
if num_warps == None:
2135+
num_warps = _num_warps(
2136+
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
2137+
)
21332138
# we are going to arrive at 2 warps only if bs was too small due to
21342139
# numel being too small. However to workaround some ptx bugs we still
21352140
# want at least 4 warps if there's enough elements per thread
@@ -2159,7 +2164,15 @@ def triton_config(
21592164
cfg["ZBLOCK"] = z
21602165
check_max_block(cfg)
21612166
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
2162-
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
2167+
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
2168+
2169+
if torch.version.hip:
2170+
if matrix_instr is not None:
2171+
config.kwargs["matrix_instr_nonkdim"] = matrix_instr
2172+
if waves_per_eu is not None:
2173+
config.kwargs["waves_per_eu"] = waves_per_eu
2174+
2175+
return config
21632176

21642177

21652178
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -2207,6 +2220,7 @@ def triton_config_reduction(
22072220
num_stages=1,
22082221
num_warps=None,
22092222
register_intensive=False,
2223+
waves_per_eu=None
22102224
) -> Config:
22112225
"""
22122226
Construct a reduction triton config with some adjustment heuristics
@@ -2250,7 +2264,13 @@ def total_numel() -> int:
22502264
cfg = _get_config({"x": x, **rnumels})
22512265
check_max_block(cfg)
22522266
check_config(cfg, xnumel=size_hints["x"])
2253-
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
2267+
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
2268+
2269+
if torch.version.hip:
2270+
if waves_per_eu is not None:
2271+
config.kwargs["waves_per_eu"] = waves_per_eu
2272+
2273+
return config
22542274

22552275

22562276
def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -2388,6 +2408,12 @@ def pointwise(
23882408
triton_config_with_settings(
23892409
size_hints, bs // 2, num_elements_per_warp=64
23902410
),
2411+
# triton_config_with_settings(
2412+
# size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
2413+
# ),
2414+
triton_config_with_settings(
2415+
size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
2416+
),
23912417
*hinted_configs,
23922418
]
23932419
if len(size_hints) == 2:
@@ -2624,6 +2650,7 @@ def reduction(
26242650

26252651
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
26262652
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
2653+
26272654
return cached_autotune(
26282655
size_hints,
26292656
configs=configs,

0 commit comments

Comments
 (0)