From 24717c36daa538ed80443c37f64bbab71a6c8a41 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Wed, 3 Sep 2025 04:46:13 -0500 Subject: [PATCH 01/28] Naive foreach autotune support (cherry picked from commit 5d4455f6bbdf8625179acffe69e6c127f6c6f12b) (cherry picked from commit d3d77f57be45d73b1f9d9e8b581fcbddcfc77e3c) --- torch/_inductor/runtime/triton_heuristics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 4d87a8236c460..cbd1a6b1d8d9d 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2275,6 +2275,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None): Compile a triton foreach kernel """ configs = [] + if disable_pointwise_autotuning(inductor_meta) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") From ec5765e4b365315fb3ccd4ca6aeace2f1dcb3f3d Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Thu, 4 Sep 2025 14:48:07 +0100 Subject: [PATCH 02/28] Update triton_heuristics.py (cherry picked from commit 2fc752529e233ff4cd496f77348da1e0831e4f98) (cherry picked from commit 528cf0206b32377dbd62875a4a73ad9099b3e8e2) --- torch/_inductor/runtime/triton_heuristics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index cbd1a6b1d8d9d..3d66c6b67cc15 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2276,6 +2276,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None): """ configs = [] + # Naive autotuning path for num_warps if disable_pointwise_autotuning(inductor_meta) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") From a49093ead7cabc50d81532fd468d50b16753865f Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:39:00 +0100 Subject: [PATCH 03/28] Update triton_heuristics.py (cherry picked from commit d5c71f01df6c09c445f8951a4fa42b8710e14fc6) (cherry picked from commit 11e1dfcb72887d9b8961645a51f03a52319341b7) --- torch/_inductor/runtime/triton_heuristics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 3d66c6b67cc15..a209d7c6eb886 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2278,8 +2278,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None): # Naive autotuning path for num_warps if disable_pointwise_autotuning(inductor_meta) and not ( - inductor_meta.get("max_autotune") or - inductor_meta.get("max_autotune_pointwise") + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") ): configs.append(triton.Config({}, num_stages=1, num_warps=8)) else: From 1363a1c24d5580530940ea4fb8795cffc4f93a72 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:39:25 +0100 Subject: [PATCH 04/28] Linting (cherry picked from commit 262a33e07e702582d731b838961b419076a693e5) (cherry picked from commit 0cf1c8953c28e0fc9326270ac054864d6ee1f79f) --- torch/_inductor/runtime/triton_heuristics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index a209d7c6eb886..db1c88c53365a 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2294,6 +2294,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None): filename=filename, ) + @dataclasses.dataclass class GridExpr: """Generate code for grid size expressions in launcher""" From f80f0b8f826a00fe31ffb3bbc9eaaaf9645422d8 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 14 Nov 2025 23:04:04 +0000 Subject: [PATCH 05/28] Conditionalize xblock values for HIP. --- torch/_inductor/runtime/triton_heuristics.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index db1c88c53365a..587b7af1de7a3 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2097,9 +2097,14 @@ def _persistent_reduction_configs( or inductor_meta.get("max_autotune_pointwise") ) + if torch.version.hip: + xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256] + else: + xblock_vals = [1, 8, 32, 128] + configs = [ triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) - for xblock in (1, 8, 32, 128) + for xblock in xblock_vals if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096)) ] From 5ed2f823e40a35c8233056fb3864fc534674b510 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 25 Sep 2025 23:29:35 +0000 Subject: [PATCH 06/28] Create tiny_config. (cherry picked from commit 9f19754839d7d2072b73de3953e456ab4eeefaae) (cherry picked from commit dee2fdf2eca9508020d079a20cf5b0a0bc0fd389) --- torch/_inductor/runtime/triton_heuristics.py | 36 +++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 587b7af1de7a3..6d0383ebdd995 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2108,27 +2108,31 @@ def _persistent_reduction_configs( if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096)) ] + tiny_configs = [ + triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + rnumel, + ) + ] + + # defer to more autotuning, initially + if "y" in size_hints: + pass # TODO(jansel): we should be able to improve these heuristics - if not max_autotune_enabled: # Don't filter if tuning enabled + elif not max_autotune_enabled: # Don't filter if tuning enabled if reduction_hint == ReductionHint.INNER and rnumel >= 256: configs = configs[:1] elif reduction_hint == ReductionHint.OUTER: configs = configs[-1:] - - if reduction_hint == ReductionHint.OUTER_TINY: - tiny_configs = [ - triton_config_reduction( - size_hints, - 2 * (256 // rnumel) if rnumel <= 256 else 1, - rnumel, - ) - ] - if max_autotune_enabled: - for tconfig in tiny_configs: - if tconfig not in configs: - configs.append(tconfig) - else: - configs = tiny_configs + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = tiny_configs + else: + if torch.version.hip: + # If autotune is enabled append tiny configs + for conf in tiny_configs: + if conf not in configs: + configs.append(conf) for c in configs: # we don't need Rn_BLOCK for persistent reduction From c269eb31826b38b4eb4a01b5f7f8a3d6ad0c4c43 Mon Sep 17 00:00:00 2001 From: Sampsa Riikonen Date: Fri, 12 Sep 2025 15:54:24 +0300 Subject: [PATCH 07/28] pointwise autotuning returnz (#2636) removed the (erroneous?) check that disables autotuning for pointwise kernels (cherry picked from commit e3b8e25f46bfc64d8cb04e2f2bec326775b2fe88) (cherry picked from commit 10af207f8fd6b609584e317323499aea40bda2f9) (cherry picked from commit b9e01825bca59a201dffa0e18b63ebc79e502838) --- torch/_inductor/runtime/triton_heuristics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 6d0383ebdd995..007b542484b8e 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1901,7 +1901,7 @@ def pointwise( ] if len(size_hints) == 2: if ( - disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE + disable_pointwise_autotuning(inductor_meta) # or tile_hint == TileHint.SQUARE ) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") From d13e652206dfbfa58fe4efc5b27ccd522abe213b Mon Sep 17 00:00:00 2001 From: Sampsa Riikonen Date: Wed, 17 Sep 2025 23:36:09 +0300 Subject: [PATCH 08/28] a few more configs for 2d grids (#2649) Added two nice grid configs for the 2d pointwise kernel cases for WRT5 workload. Confirmed that they were picked up when using max autotune. (cherry picked from commit f1eac493829d676f08484b47be72d3c5b74d2d9c) (cherry picked from commit 2e79001835de014004fc43ea09d9b1c2e0dd643e) (cherry picked from commit 04aa3e4f5525816a47389f5403d945a4b46ee508) --- torch/_inductor/runtime/triton_heuristics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 007b542484b8e..5cf0ca75dff14 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1913,6 +1913,8 @@ def pointwise( triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 triton_config_with_settings(size_hints, 256, 16), triton_config_with_settings(size_hints, 16, 256), + triton_config_with_settings(size_hints, 128, 16), # wrt: +10% for some kernels + triton_config_with_settings(size_hints, 32, 512), # wrt: +30% for some kernels triton_config_with_settings(size_hints, bs, 1), triton_config_with_settings(size_hints, 1, bs), *hinted_configs, From a13015c18382c569df201d8c658bb0e7e7476b17 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" <165712832+naromero77amd@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:30:17 -0500 Subject: [PATCH 09/28] [ROCm][inductor] Additional pointwise tunings (#2642) This config improves the performance of a 1D pointwise kernel by 20% as measured on MI350. (cherry picked from commit a7bac0ac90970ac28175dc76ded5bd937a3a2606) (cherry picked from commit 0bdb796fe7d652be400d3b683b4660d1fd130071) (cherry picked from commit af5f678b49c16a3e6ebaeb7623f74f34958bb8f1) --- torch/_inductor/runtime/triton_heuristics.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5cf0ca75dff14..b4d323d8edb86 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1899,6 +1899,15 @@ def pointwise( ), *hinted_configs, ] + # Additional reduction configs appended for ROCm builds + if torch.version.hip: + configs.append(triton_config_with_settings( + size_hints, + 2048, + num_warps=8, + num_stages=2, + waves_per_eu=1 + )) # 20% improvement if len(size_hints) == 2: if ( disable_pointwise_autotuning(inductor_meta) # or tile_hint == TileHint.SQUARE From 3d716ebbb9ffdd92f1b8a2f4ab7d3e78239b7081 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 17 Sep 2025 21:26:30 +0000 Subject: [PATCH 10/28] pointwise config with MAX_BLOCK. (cherry picked from commit 16e82664843495d6452b3e1b29cdba9dd81477a5) (cherry picked from commit 8bd33f9f509a06ee65684c2513be055094e99e0c) --- torch/_inductor/runtime/triton_heuristics.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index b4d323d8edb86..a1649515e3a48 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1897,6 +1897,9 @@ def pointwise( triton_config_with_settings( size_hints, bs // 2, num_elements_per_warp=64 ), + triton_config_with_settings( + size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 + ), *hinted_configs, ] # Additional reduction configs appended for ROCm builds From 7edd183a7361ba6fd2f6aa96fad0d5f6cf7bae0e Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 17 Sep 2025 22:07:15 +0000 Subject: [PATCH 11/28] Update triton_config method. (cherry picked from commit dfc15795680a7f7c444230f1a1a25643654ecc60) (cherry picked from commit 8f60456e7fba76d32ed1ed78a5d086b44ac60eea) --- torch/_inductor/runtime/triton_heuristics.py | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index a1649515e3a48..0a1fdf867faf7 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1630,6 +1630,9 @@ def triton_config( num_stages=1, num_elements_per_warp=256, min_elem_per_thread=0, + num_warps=None, + matrix_instr=None, + waves_per_eu=None, ) -> Config: """ Construct a pointwise triton config with some adjustment heuristics @@ -1686,9 +1689,11 @@ def triton_config( ): z *= 2 - num_warps = _num_warps( - conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 - ) + # Calculate num_warps if they are not hard passed to config + if num_warps is None: + num_warps = _num_warps( + conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 + ) # we are going to arrive at 2 warps only if bs was too small due to # numel being too small. However to workaround some ptx bugs we still # want at least 4 warps if there's enough elements per thread @@ -1718,7 +1723,15 @@ def triton_config( cfg["ZBLOCK"] = z check_max_block(cfg) check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) - return Config(cfg, num_warps=num_warps, num_stages=num_stages) + config = Config(cfg, num_warps=num_warps, num_stages=num_stages) + + if torch.version.hip: + if matrix_instr is not None: + config.kwargs["matrix_instr_nonkdim"] = matrix_instr + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + + return config def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]: From adc6a454901a3a124dc4c86f9b6cfc6673246dcf Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 18 Sep 2025 00:51:09 +0000 Subject: [PATCH 12/28] Increase TRITON_MAX_BLOCK X value. (cherry picked from commit 666e81bc73f1d1ed3525fd9605408d1bf3149e80) (cherry picked from commit f6aaaf8bc2085c5742ae26dbd7c34d0699800203) --- torch/_inductor/runtime/hints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 3bc8df35a8389..2a045f1167e9c 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -13,7 +13,7 @@ # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values # NOTE: if these fail asserts submit a PR to increase them TRITON_MAX_BLOCK = { - "X": 4096, + "X": 8192, "Y": 1024, "Z": 1024, "R0_": 4096 * 16, # * 16 is multi-kernel only From b1f1c364869e44a73ae29c0b1710b324f6f54288 Mon Sep 17 00:00:00 2001 From: AmdSampsa Date: Fri, 19 Sep 2025 11:33:11 +0000 Subject: [PATCH 13/28] even more configs (cherry picked from commit f97c7a91f36af6dd249e63809c44d21543ae7719) (cherry picked from commit db49466eef1a7e238f7d5e2a48edb7fa1e3aab2e) (cherry picked from commit 6e9b4ee81cd276089f3f4f06816ed9439ca5e9f8) (cherry picked from commit c36d85fcfd457025daa3721904dfb738cfe0b437) --- torch/_inductor/runtime/triton_heuristics.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 0a1fdf867faf7..0eaa0aa67e4a1 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1913,6 +1913,9 @@ def pointwise( triton_config_with_settings( size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 ), + triton_config_with_settings( + size_hints, 4096 # wrt: better than the max_block for some kernel + ), *hinted_configs, ] # Additional reduction configs appended for ROCm builds @@ -1935,10 +1938,12 @@ def pointwise( else: configs = [ triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings(size_hints, 64, 32), # wrt: better for some kernels triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 triton_config_with_settings(size_hints, 256, 16), triton_config_with_settings(size_hints, 16, 256), triton_config_with_settings(size_hints, 128, 16), # wrt: +10% for some kernels + triton_config_with_settings(size_hints, 128, 32), # wrt: ..additional 10% more triton_config_with_settings(size_hints, 32, 512), # wrt: +30% for some kernels triton_config_with_settings(size_hints, bs, 1), triton_config_with_settings(size_hints, 1, bs), From 26f81334c2d7556f5ae67ece9f79e7a9fc109a9f Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Tue, 23 Sep 2025 14:47:03 +0100 Subject: [PATCH 14/28] More ROCm conditionalisation (cherry picked from commit 0c52d012c1bb41b464a61d4f3a10d906a824dc61) (cherry picked from commit 83e453f49b4e253009c6f2aa8df3d8891abb55fc) --- torch/_inductor/runtime/triton_heuristics.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 0eaa0aa67e4a1..5b5a7970d45fd 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1928,8 +1928,11 @@ def pointwise( waves_per_eu=1 )) # 20% improvement if len(size_hints) == 2: + # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds + # ROCm has observed improvement by diverging here if ( - disable_pointwise_autotuning(inductor_meta) # or tile_hint == TileHint.SQUARE + disable_pointwise_autotuning(inductor_meta) + or (torch.version.hip is None and tile_hint == TileHint.SQUARE) ) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") @@ -1938,13 +1941,13 @@ def pointwise( else: configs = [ triton_config_with_settings(size_hints, 32, 32), - triton_config_with_settings(size_hints, 64, 32), # wrt: better for some kernels + triton_config_with_settings(size_hints, 64, 32), # better for some kernels triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 triton_config_with_settings(size_hints, 256, 16), triton_config_with_settings(size_hints, 16, 256), - triton_config_with_settings(size_hints, 128, 16), # wrt: +10% for some kernels - triton_config_with_settings(size_hints, 128, 32), # wrt: ..additional 10% more - triton_config_with_settings(size_hints, 32, 512), # wrt: +30% for some kernels + triton_config_with_settings(size_hints, 128, 16), # +10% for some kernels + triton_config_with_settings(size_hints, 128, 32), # additional 10% more + triton_config_with_settings(size_hints, 32, 512), # +30% for some kernels triton_config_with_settings(size_hints, bs, 1), triton_config_with_settings(size_hints, 1, bs), *hinted_configs, From c1f8e9912fa6179734789cb53223d4c6edfa1f97 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 25 Sep 2025 19:02:03 +0000 Subject: [PATCH 15/28] Lint. (cherry picked from commit dd990a38025fd82a3ae1e7a9068c256a9e08375b) (cherry picked from commit 0de435fa189d1c191cfbac63c04d0d1fc1c0a53e) --- torch/_inductor/runtime/triton_heuristics.py | 33 +++++++++++--------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5b5a7970d45fd..0acdfb286e996 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1914,25 +1914,24 @@ def pointwise( size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 ), triton_config_with_settings( - size_hints, 4096 # wrt: better than the max_block for some kernel + size_hints, + 4096, # wrt: better than the max_block for some kernel ), *hinted_configs, ] # Additional reduction configs appended for ROCm builds if torch.version.hip: - configs.append(triton_config_with_settings( - size_hints, - 2048, - num_warps=8, - num_stages=2, - waves_per_eu=1 - )) # 20% improvement + configs.append( + triton_config_with_settings( + size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1 + ) + ) # 20% improvement if len(size_hints) == 2: # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds # ROCm has observed improvement by diverging here if ( - disable_pointwise_autotuning(inductor_meta) - or (torch.version.hip is None and tile_hint == TileHint.SQUARE) + disable_pointwise_autotuning(inductor_meta) + or (torch.version.hip is None and tile_hint == TileHint.SQUARE) ) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") @@ -1941,13 +1940,19 @@ def pointwise( else: configs = [ triton_config_with_settings(size_hints, 32, 32), - triton_config_with_settings(size_hints, 64, 32), # better for some kernels + triton_config_with_settings( + size_hints, 64, 32 + ), # better for some kernels triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 triton_config_with_settings(size_hints, 256, 16), triton_config_with_settings(size_hints, 16, 256), - triton_config_with_settings(size_hints, 128, 16), # +10% for some kernels - triton_config_with_settings(size_hints, 128, 32), # additional 10% more - triton_config_with_settings(size_hints, 32, 512), # +30% for some kernels + triton_config_with_settings( + size_hints, 128, 16 + ), # +10% for some kernels + triton_config_with_settings(size_hints, 128, 32), # additional 10% more + triton_config_with_settings( + size_hints, 32, 512 + ), # +30% for some kernels triton_config_with_settings(size_hints, bs, 1), triton_config_with_settings(size_hints, 1, bs), *hinted_configs, From 0a9dcddeebded27497f87867dc32490606d9956d Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 22 Aug 2025 16:50:49 +0000 Subject: [PATCH 16/28] Reduction heursitics improvements for ROCm (cherry picked from commit 9534cbd5a81a6268e97705fb5858c9e01306edac) (cherry picked from commit 189481ed7129c40c173108029535b95884ffd252) --- torch/_inductor/codegen/triton.py | 18 +- torch/_inductor/config.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 268 +++++++++++++++++-- 3 files changed, 259 insertions(+), 29 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e9c5b910ba02f..e7c873370e6c5 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1002,11 +1002,17 @@ def relu(x): @staticmethod def minimum(a, b): - return f"triton_helpers.minimum({a}, {b})" + if torch.version.hip: + return f"tl.minimum({a}, {b})" + else: + return f"triton_helpers.minimum({a}, {b})" @staticmethod def maximum(a, b): - return f"triton_helpers.maximum({a}, {b})" + if torch.version.hip: + return f"tl.maximum({a}, {b})" + else: + return f"triton_helpers.maximum({a}, {b})" @staticmethod def where(a, b, c): @@ -1202,7 +1208,10 @@ def load_seed(name, offset): @staticmethod @maybe_upcast_float32() def rsqrt(x): - return f"libdevice.rsqrt({x})" + if torch.version.hip: + return f"tl.rsqrt({x})" + else: + return f"libdevice.rsqrt({x})" @staticmethod @maybe_upcast_float32() @@ -3227,8 +3236,9 @@ def codegen_body(self): loop_end = ( "rsplit_end" if self.cooperative_reduction else f"{prefix}numel" ) + num_stages = ", num_stages = 2" if torch.version.hip else "" self.body.writeline( - f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):" + f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):" ) with self.body.indent(offset=level + 1): self.iteration_ranges_codegen_header(tree, self.body) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 0b58772fbf0b6..13a1e731892cc 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1135,7 +1135,7 @@ class triton: # So far we see a fixed 8 spilled registers for kernels using sin/cos. # Raise the threshold to 16 to be safe. # We should revisit this once we understand more of the source of register spills. - spill_threshold: int = 16 + spill_threshold: int = 32 if torch.version.hip else 16 # Generate code containing the newer tl.make_block_ptr() API for loads/store use_block_ptr = False diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 0acdfb286e996..a2058d97bd7a8 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -615,7 +615,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs): # for some (complicated) custom Triton kernels, a register-spilling # config may yield the best latency. if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( - "spill_threshold", 16 + "spill_threshold", 32 if torch.version.hip else 16 ): log.debug( "Skip config %s because of register spilling: %d", @@ -1779,6 +1779,8 @@ def triton_config_reduction( num_stages=1, num_warps=None, register_intensive=False, + waves_per_eu=None, + dynamic_scale_rblock=True, ) -> Config: """ Construct a reduction triton config with some adjustment heuristics @@ -1822,7 +1824,18 @@ def total_numel() -> int: cfg = _get_config({"x": x, **rnumels}) check_max_block(cfg) check_config(cfg, xnumel=size_hints["x"]) - return Config(cfg, num_warps=num_warps, num_stages=num_stages) + config = InductorConfig( + cfg, + num_warps=num_warps, + num_stages=num_stages, + dynamic_scale_rblock=dynamic_scale_rblock, + ) + + if torch.version.hip: + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + + return config def _get_config(numels: dict[str, int]) -> dict[str, int]: @@ -1833,7 +1846,9 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]: return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()} -def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1): +def triton_config_tiled_reduction( + size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None +): """ Construct a tile reduction triton config with some adjustment heuristics based on size_hints. Size_hints is a tuple of numels in @@ -1866,7 +1881,11 @@ def total_numel() -> int: num_warps = _num_warps(total_numel() // 256, min_num_warps=1) check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1]) check_max_block(cfg) - return Config(cfg, num_warps=num_warps, num_stages=num_stages) + config = Config(cfg, num_warps=num_warps, num_stages=num_stages) + if torch.version.hip: + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + return config def pointwise( @@ -1992,6 +2011,11 @@ def _reduction_configs( # Convert reductions to 1D, to simplify heuristics. rnumel = get_total_reduction_numel(size_hints) + # Is max autotune enabled + max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get( + "max_autotune_pointwise" + ) + register_intensive = False MAX_R0_BLOCK = 2048 if ( @@ -2014,8 +2038,90 @@ def _reduction_configs( MAX_R0_BLOCK = 1024 register_intensive = True - contiguous_config = triton_config_reduction( - size_hints, + def make_config( + x, + r, + num_warps=None, + num_stages=1, + register_intensive=False, + dynamic_scale_rblock=True, + waves_per_eu=None, + ): + # For 3D case with tiling scores, create an adapted version + if "y" in size_hints: + assert "tiling_scores" in inductor_meta + return adapt_config_for_tiling( + size_hints, + inductor_meta["tiling_scores"], + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + waves_per_eu=waves_per_eu, + ) + else: + # For other cases, use the original function + return triton_config_reduction( + size_hints, + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + waves_per_eu=waves_per_eu, + dynamic_scale_rblock=dynamic_scale_rblock, + ) + + def outer_config_opt(): + # Default to 64 for vectorized loads + max_x_block, x_block = 256, 64 + load_factor = inductor_meta.get("num_load", 0) + x = size_hints["x"] + num_warps = None + + # Try to use all SMs with small x + if x <= 1024: + x_block = max(min(x // 128, 8), 2) + outer_r_block = min(rnumel, 64) + # Lower bound x = 1024, 1024 // 16 = 128 around # of SMs + elif x // 4096 <= 8: + x_block = 16 + outer_r_block = 512 // x_block + elif num_dynamic > 1: + # Lots of compute with multiple dynamic shape per loop iteration + # Larger RBLOCK minimizes loop iteration + outer_r_block = max(min((rnumel // 64), 64), 8) + elif num_dynamic == 1: + # Dynamic shapes introduce a lot register pressure for indexing + outer_r_block = ( + 1 + if load_factor >= 3 + else min(next_power_of_2(max(rnumel, 128) // 128), 8) + ) + else: + x_block = max(min(max_x_block, next_power_of_2(x // 4096)), x_block) + if load_factor < 4 or rnumel <= 128: + outer_r_block = 512 // x_block + else: + # Heavier reductions contain a lot more overhead per loop iteration + # We minimize the overhead by enlarging r block + if rnumel >= 2048: + outer_r_block = 64 + else: + outer_r_block = 32 + x_block = min(x_block, 32) + num_warps = 4 + + # Set register intensive to true by default as we try to maximize tiles with heuristic + return make_config( + x_block, + outer_r_block, + num_warps=num_warps, + register_intensive=register_intensive, + ) + + contiguous_config = make_config( 1, rnumel if 256 <= rnumel < MAX_R0_BLOCK else MAX_R0_BLOCK, register_intensive=register_intensive, @@ -2029,27 +2135,141 @@ def _reduction_configs( min(rnumel, MAX_R0_BLOCK), register_intensive=register_intensive, ) - if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"): - pass # skip all these cases + + outer_config = make_config(64, 8, register_intensive=register_intensive) + # TODO (paulzhan): Test heuristic on AMD and internal testing + # for correctness + if not torch.version.hip and not is_fbcode(): + outer_config = outer_config_opt() + + configs = [] + + if inductor_meta.get("add_persistent_rblock") and loads_and_red <= 8: + xnumel = max(4096 // rnumel, 1) + c = make_config( + xnumel, + rnumel, + register_intensive=register_intensive, + dynamic_scale_rblock=False, + ) + configs.append(c) + + result_configs = [] + + # For 3d tiling, default to more autotuning initially + if "y" in size_hints: + pass + elif max_autotune_enabled: + pass # skip all these cases elif reduction_hint == ReductionHint.INNER: - return [contiguous_config] + return configs + [contiguous_config] elif reduction_hint == ReductionHint.OUTER: - return [outer_config] + return configs + [outer_config] elif reduction_hint == ReductionHint.OUTER_TINY: - return [tiny_config] - if disable_pointwise_autotuning(inductor_meta): - return [triton_config_reduction(size_hints, 32, 128)] - return [ - contiguous_config, - outer_config, - tiny_config, - triton_config_reduction(size_hints, 64, 64), - triton_config_reduction(size_hints, 8, 512), - # halve the XBLOCK/Rn_BLOCK compared to outer_config - # TODO: this may only be beneficial when each iteration of the reduction - # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 - triton_config_reduction(size_hints, 64, 4, num_warps=8), - ] + return configs + [tiny_config] + + # We continue here under the following conditions: + # - max_autotune_enabled is True + # - max_autotune_enabled is False and reduction_hint is NOT one of the above cases + result_configs = configs + [ + contiguous_config, + outer_config, + tiny_config, + make_config(64, 64), + make_config(8, 512), + # halve the XBLOCK/Rn_BLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + make_config(64, 4, num_warps=8), + ] + + if torch.version.hip: + result_configs.extend( + [ + make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2), + make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1), + ] + ) + + return result_configs + + +def match_target_block_product( + size_hints, tiling_scores, target_block_product, min_block_size=1 +): + """ + Distribute block sizes across dimensions according to tiling scores, + aiming to match a target product of block sizes. + """ + total_score = sum(tiling_scores.values()) + if total_score == 0: + # just assume even score with no minimum block size + min_block_size = 1 + tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product) + + # First, give each coalescing dimension at least min_block_size + block_sizes = {} + relative_scores = {} + curr_block_product = 1 + + for dim, score in tiling_scores.items(): + if score == 0: + block_sizes[dim] = 1 + continue + + block_sizes[dim] = min_block_size + curr_block_product *= min_block_size + relative_scores[dim] = score / total_score + + # Scale up dimensions by their relative scores until we reach the target + while curr_block_product < target_block_product and len(relative_scores): + dim, score = max(relative_scores.items(), key=lambda item: item[1]) + + # Check if we've hit the max for this dimension + if ( + block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()] + or block_sizes[dim] >= size_hints[dim] + ): + del relative_scores[dim] + continue + + block_sizes[dim] *= 2 + relative_scores[dim] /= 2 + curr_block_product *= 2 + + return block_sizes + + +def adapt_config_for_tiling( + size_hints, + tiling_scores, + original_x, + original_r, + num_warps=None, + num_stages=1, + register_intensive=False, + persistent_reduction=False, + waves_per_eu=None, +) -> Config: + """ + Create an adapted configuration based on tiling scores, + redistributing the same total block size (x * r) according to tiling scores. + """ + assert all(s in tiling_scores for s in size_hints) + target_block_product = original_x * original_r + block_sizes = match_target_block_product( + size_hints, tiling_scores, target_block_product + ) + + return triton_config_tiled_reduction( + size_hints, + block_sizes["x"], + block_sizes["y"], + block_sizes["r0_"], + num_stages=num_stages, + register_intensive=register_intensive, + waves_per_eu=waves_per_eu, + ) def reduction( From 768ab8442cb0b1d52e7c878e441b63bce31faac0 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 2 Oct 2025 15:16:54 +0000 Subject: [PATCH 17/28] Add PropagateNan argument to minimum and maximum function. (cherry picked from commit 7eeb1ba4124addcaf52e7585369c586cdd9e35e0) (cherry picked from commit eea659c381d7e85f1536bd22701ab6b336c77015) --- torch/_inductor/codegen/triton.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e7c873370e6c5..aa0f7e1e73a73 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1003,14 +1003,14 @@ def relu(x): @staticmethod def minimum(a, b): if torch.version.hip: - return f"tl.minimum({a}, {b})" + return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)" else: return f"triton_helpers.minimum({a}, {b})" @staticmethod def maximum(a, b): if torch.version.hip: - return f"tl.maximum({a}, {b})" + return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)" else: return f"triton_helpers.maximum({a}, {b})" From 451c2b4ba41c607b8a8b33b07c19929a2b62ef06 Mon Sep 17 00:00:00 2001 From: Sampsa Riikonen Date: Tue, 14 Oct 2025 17:53:22 +0300 Subject: [PATCH 18/28] New WRT configs for autotuning (#2708) Reorganized slightly the adding of hard-coded autotuning configs. Fixed wrt1 configs. Added wrt2 & 3 configs. (cherry picked from commit e3e9a178e277ce54fd4287f3298a8388ed6e2d7e) (cherry picked from commit 6534df0b317e7780b2119fc780869a872126c9b1) --- torch/_inductor/runtime/triton_heuristics.py | 31 ++++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index a2058d97bd7a8..2ca57bdc67afe 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1945,6 +1945,15 @@ def pointwise( size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1 ) ) # 20% improvement + configs += [ + triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where? + triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel + triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1), + # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37, + # triton_poi_fused_index_put_new_zeros_45 + # triton_poi_fused_index_put_new_zeros_49 + # triton_poi_fused_index_put_new_zeros_54 + ] if len(size_hints) == 2: # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds # ROCm has observed improvement by diverging here @@ -1963,7 +1972,7 @@ def pointwise( size_hints, 64, 32 ), # better for some kernels triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 - triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 256, 16), triton_config_with_settings(size_hints, 16, 256), triton_config_with_settings( size_hints, 128, 16 @@ -1976,6 +1985,17 @@ def pointwise( triton_config_with_settings(size_hints, 1, bs), *hinted_configs, ] + if torch.version.hip: + configs += [ # add here + ] + # bypass triton_config_with_settings -> triton_config logic + if "x" in size_hints and "y" in size_hints: + configs += [ + Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 + Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52 + Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 + ] + if len(size_hints) == 3: if disable_pointwise_autotuning(inductor_meta): configs = [triton_config_with_settings(size_hints, 16, 16, 16)] @@ -2188,8 +2208,13 @@ def outer_config_opt(): [ make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2), make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1), - ] - ) + make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8 + make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4 + make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153 + make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16 + make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29 + ] + ) return result_configs From 7850a9c97813ff2687769efd9a6c4ff5ff749187 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Sat, 15 Nov 2025 01:29:20 +0000 Subject: [PATCH 19/28] Align TRITON_MAX_BLOCK code with upstream. --- torch/_inductor/runtime/hints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 2a045f1167e9c..ec981f4a786df 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -7,13 +7,14 @@ from enum import auto, Enum from typing import Optional, Union +import torch from torch.utils._triton import has_triton_package # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values # NOTE: if these fail asserts submit a PR to increase them TRITON_MAX_BLOCK = { - "X": 8192, + "X": 8192 if torch.version.hip else 4096, "Y": 1024, "Z": 1024, "R0_": 4096 * 16, # * 16 is multi-kernel only From dbdb5542c2ae0f09415495c33bfd7d5d0f77bc53 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Sat, 15 Nov 2025 02:09:41 +0000 Subject: [PATCH 20/28] Fix indentation. --- torch/_inductor/runtime/triton_heuristics.py | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 2ca57bdc67afe..02a783c01de99 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2203,18 +2203,18 @@ def outer_config_opt(): make_config(64, 4, num_warps=8), ] - if torch.version.hip: - result_configs.extend( - [ - make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2), - make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1), - make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8 - make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4 - make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153 - make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16 - make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29 - ] - ) + if torch.version.hip: + result_configs.extend( + [ + make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2), + make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1), + make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8 + make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4 + make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153 + make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16 + make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29 + ] + ) return result_configs From 83e3be03958a2728f66b3fc46076d56f2a518a0f Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Tue, 18 Nov 2025 21:14:56 +0000 Subject: [PATCH 21/28] Backport InductorConfig class. --- torch/_inductor/runtime/triton_heuristics.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 02a783c01de99..c9540575b26b4 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -69,6 +69,14 @@ ) +class InductorConfig(Config): + """Inductor-specific Triton config with additional control flags""" + + def __init__(self, *args, dynamic_scale_rblock=True, **kwargs): + super().__init__(*args, **kwargs) + self.dynamic_scale_rblock = dynamic_scale_rblock + + class NoTritonConfigsError(RuntimeError): pass From d235a1504f6702249dd72deef1a8f68ce991320a Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 19 Nov 2025 05:46:12 +0000 Subject: [PATCH 22/28] Patch get_args_with_constexprs --- torch/_inductor/runtime/triton_heuristics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c9540575b26b4..1d2383b9ecaa4 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -605,7 +605,8 @@ def _get_args_with_constexprs(self, args, launcher): # so we can sort them by index. constexpr_args: list[tuple[int, Any]] = [] for arg_name, arg_val in launcher.config.kwargs.items(): - constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) + if arg_name in self.fn.arg_names: + constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) constexpr_args.sort() new_args = [*args] From 627a5718c93f8c54fca6787f3167b2b454717226 Mon Sep 17 00:00:00 2001 From: Sampsa Riikonen Date: Fri, 14 Nov 2025 18:55:38 +0200 Subject: [PATCH 23/28] [NO CP] triton sanity check for 2D POI (#2798) Added a check that includes autotune configs for 2D POI only if their size is big enough. (cherry picked from commit a2b0fd7c9561ae520ec166a624abafbc2ebcd0d4) --- torch/_inductor/runtime/triton_heuristics.py | 31 +++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 1d2383b9ecaa4..c33420f342ad5 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1995,15 +1995,30 @@ def pointwise( *hinted_configs, ] if torch.version.hip: - configs += [ # add here - ] - # bypass triton_config_with_settings -> triton_config logic if "x" in size_hints and "y" in size_hints: - configs += [ - Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 - Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52 - Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 - ] + """add 2D tiling configs, but don't use triton_config_with_settings function + as it is buggy and might change the tiling randomly + """ + def addConfig__(xblock:int, yblock:int, num_warps:int, num_stages:int): + # only add a tiling config if size is bigger than the tile + # check also for grid overflow + xgrid = (size_hints["x"] + xblock - 1) // xblock + ygrid = (size_hints["y"] + yblock - 1) // yblock + if xgrid > 2147483647: + return + if ygrid > 65535: + return + if size_hints["x"] < xblock: + return + if size_hints["y"] < yblock: + return + # all good, add the config + configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps, num_stages=num_stages)) + addConfig__(512, 8, 8,1 ) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 + addConfig__(32, 128, 4, 1) # wrt2: 570us : triton_poi_fused_add_transpose_view_52 + addConfig__(64, 32, 8, 1) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 + addConfig__(64, 256, 4, 1) # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19 + addConfig__(512, 64, 8, 1) # wri0: 58us: triton_poi_fused_clone_53 if len(size_hints) == 3: if disable_pointwise_autotuning(inductor_meta): From c48cc4c02b9bbddc24d6b3b3de368985f4d490b4 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 21 Nov 2025 01:32:56 +0000 Subject: [PATCH 24/28] Config specific to 1D pointwise kernels with atomic_add. --- torch/_inductor/codegen/common.py | 1 + torch/_inductor/codegen/triton.py | 2 ++ torch/_inductor/runtime/triton_heuristics.py | 11 +++++++++++ 3 files changed, 14 insertions(+) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 152d2ef36197f..98908d7cb696a 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1873,6 +1873,7 @@ def __init__( self.compute = IndentedBuffer() self.stores = IndentedBuffer() + self.atomic_add_found = False self.num_load = 0 self.num_reduction = 0 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index aa0f7e1e73a73..d8c5b35fe3972 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2294,6 +2294,7 @@ def store( elif mode is None: line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})" elif mode == "atomic_add": + self.atomic_add_found = True line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')" else: raise NotImplementedError(f"store mode={mode}") @@ -3611,6 +3612,7 @@ def add_constexpr_arg(arg_name): "mutated_arg_names": mutated_args, "optimize_mem": optimize_mem, "no_x_dim": self.no_x_dim, + "atomic_add_found": self.atomic_add_found, "num_load": self.num_load, "num_reduction": self.num_reduction, **self.inductor_meta_common(), diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c33420f342ad5..fc3f3eedd6788 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1963,6 +1963,17 @@ def pointwise( # triton_poi_fused_index_put_new_zeros_49 # triton_poi_fused_index_put_new_zeros_54 ] + if inductor_meta.get("atomic_add_found"): + configs.extend( + [ + triton_config_with_settings( + size_hints, + 64, + num_warps=1, + num_stages=1, # 250% improvement + ) + ] + ) if len(size_hints) == 2: # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds # ROCm has observed improvement by diverging here From e824f40d6ff7c69998544c190f04276a7fe488e1 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 21 Nov 2025 01:38:46 +0000 Subject: [PATCH 25/28] Lint --- torch/_inductor/runtime/triton_heuristics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index fc3f3eedd6788..7f65e6221e6fa 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1957,9 +1957,9 @@ def pointwise( configs += [ triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where? triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel - triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1), - # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37, - # triton_poi_fused_index_put_new_zeros_45 + triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1), + # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37, + # triton_poi_fused_index_put_new_zeros_45 # triton_poi_fused_index_put_new_zeros_49 # triton_poi_fused_index_put_new_zeros_54 ] @@ -1992,7 +1992,7 @@ def pointwise( size_hints, 64, 32 ), # better for some kernels triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 - triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 256, 16), triton_config_with_settings(size_hints, 16, 256), triton_config_with_settings( size_hints, 128, 16 @@ -2604,7 +2604,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None): Compile a triton foreach kernel """ configs = [] - + # Naive autotuning path for num_warps if disable_pointwise_autotuning(inductor_meta) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") From ef5f375f81728e2fc449636a303d0dc188c23479 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 21 Nov 2025 02:08:55 +0000 Subject: [PATCH 26/28] No change, just re-ordering of configs to match upstream. --- torch/_inductor/runtime/triton_heuristics.py | 36 ++++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 7f65e6221e6fa..ca56692cde853 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1938,26 +1938,32 @@ def pointwise( triton_config_with_settings( size_hints, bs // 2, num_elements_per_warp=64 ), - triton_config_with_settings( - size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 - ), - triton_config_with_settings( - size_hints, - 4096, # wrt: better than the max_block for some kernel - ), *hinted_configs, ] # Additional reduction configs appended for ROCm builds if torch.version.hip: - configs.append( - triton_config_with_settings( - size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1 - ) - ) # 20% improvement + configs.extend( + [ + triton_config_with_settings( + size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 + ), + triton_config_with_settings( + size_hints, + 4096, # wrt: better than the max_block for some kernel + ), + triton_config_with_settings( + size_hints, + 2048, + num_warps=8, + num_stages=2, + waves_per_eu=1, # 20% improvement + ), + ] + ) configs += [ - triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where? - triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel - triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1), + triton_config_with_settings( + size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1 + ), # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37, # triton_poi_fused_index_put_new_zeros_45 # triton_poi_fused_index_put_new_zeros_49 From 3ad2e711d449c9d8203953067edf15368364d43c Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 21 Nov 2025 02:15:09 +0000 Subject: [PATCH 27/28] Remove fbcode helper function. --- torch/_inductor/runtime/triton_heuristics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index ca56692cde853..4b50a7bcb789b 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2200,7 +2200,7 @@ def outer_config_opt(): outer_config = make_config(64, 8, register_intensive=register_intensive) # TODO (paulzhan): Test heuristic on AMD and internal testing # for correctness - if not torch.version.hip and not is_fbcode(): + if not torch.version.hip: outer_config = outer_config_opt() configs = [] From b1cdd5584626c1f0c2c6bad6b58272da6901e619 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 21 Nov 2025 02:22:57 +0000 Subject: [PATCH 28/28] Missing loads_and_red variable. --- torch/_inductor/runtime/triton_heuristics.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 4b50a7bcb789b..0f3d3e0762eb5 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2079,11 +2079,10 @@ def _reduction_configs( register_intensive = False MAX_R0_BLOCK = 2048 - if ( - size_hints["x"] >= 1024 - and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0) - >= 10 - ): + loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get( + "num_reduction", 0 + ) + if size_hints["x"] >= 1024 and loads_and_red >= 10: # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers. # Consider load and reduction since load need move data into registers and # reduction needs an accumulator.