
Commit f2dc4ba

chinmaydk99AMD authored and committed
Fixing multi-kernel autotune for different size hints on ROCm
1 parent 962f13f commit f2dc4ba

File tree

6 files changed, +55 -14 lines changed

test/inductor/test_multi_kernel.py

Lines changed: 9 additions & 7 deletions
@@ -16,7 +16,6 @@
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
-    skipIfRocm,
     skipIfXpu,
 )
 from torch.testing._internal.inductor_utils import (
@@ -108,8 +107,6 @@ def test_softmax(self, expect_multi_kernel=True):
         self.assertFalse(_contains_multi_kernel_code(wrapper_code))
 
     @requires_triton()
-    # TODO: bobrenjc93 to fix multi-kernel for ROCM
-    @skipIfRocm
     @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
     @skipIfXpu(msg="https://github.com/intel/torch-xpu-ops/issues/2295")
     def test_triton_gemm(self):
@@ -133,13 +130,14 @@ def fn(x, y):
         # One for the first pass and one for the second pass.
         # We mainly care about the wrapper for the final pass here.
         wrapper_code = wrapper_code[-1]
-        self.assertEqual(ref, act)
+        if torch.version.hip:
+            self.assertEqual(ref, act, atol=1e-3, rtol=1e-3)
+        else:
+            self.assertEqual(ref, act)
         self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))
 
     @skipIfXpu(msg="https://github.com/intel/torch-xpu-ops/issues/2295")
     @requires_triton()
-    # TODO: bobrenjc93 to fix multi-kernel for ROCM
-    @skipIfRocm
     @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
     def test_triton_relu_fused_gemm(self):
         def fn(x, y):
@@ -162,7 +160,11 @@ def fn(x, y):
         # One for the first pass and one for the second pass.
         # We mainly care about the wrapper for the final pass here.
         wrapper_code = wrapper_code[-1]
-        self.assertEqual(ref, act)
+        if torch.version.hip:
+            self.assertEqual(ref, act, atol=1e-3, rtol=1e-3)
+        else:
+            self.assertEqual(ref, act)
+
         self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))
 
     @parametrize("force_kernel", (0, 1))
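
For context, the assertion change above boils down to the following pattern: on ROCm builds (where torch.version.hip is set) the GEMM comparison uses a looser tolerance, while other builds keep the default comparison. A minimal standalone sketch, not taken from the test file itself (it uses torch.testing.assert_close rather than the test class's assertEqual):

import torch
from torch.testing import assert_close

def check_gemm_result(ref: torch.Tensor, act: torch.Tensor) -> None:
    # ROCm builds report torch.version.hip; autotuned Triton GEMM variants may
    # accumulate in a different order there, so allow a small tolerance.
    if torch.version.hip:
        assert_close(act, ref, atol=1e-3, rtol=1e-3)
    else:
        assert_close(act, ref)  # default tolerances elsewhere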

torch/_inductor/autotune_process.py

Lines changed: 7 additions & 1 deletion
@@ -368,7 +368,10 @@ class TensorMeta:
 
     @classmethod
     def from_irnodes(
-        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
+        cls,
+        irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]],
+        *,
+        hint_override: Optional[int] = None,
     ) -> Union[TensorMeta, list[TensorMeta]]:
         if isinstance(irnodes, Sequence):
             result: list[Any] = [cls.from_irnodes(x) for x in irnodes]
@@ -390,14 +393,17 @@ def from_irnodes(
             sizes=V.graph.sizevars.size_hints(
                 node.get_size(),
                 fallback=config.unbacked_symint_fallback,
+                hint_override=hint_override,
             ),
             strides=V.graph.sizevars.size_hints(
                 node.get_stride(),
                 fallback=config.unbacked_symint_fallback,
+                hint_override=hint_override,
             ),
             offset=V.graph.sizevars.size_hint(
                 node.get_layout().offset,
                 fallback=config.unbacked_symint_fallback,
+                hint_override=hint_override,
             ),
             name=node.get_name(),
         )
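
The signature change above threads a keyword-only hint_override from TensorMeta.from_irnodes into the size-hint lookups, so benchmarking tensors can be shaped for a specific multi-kernel size bucket. A rough standalone sketch of that pass-through pattern (simplified helper names, not the real SizeVarAllocator API):

from typing import Optional, Sequence

def size_hint(size, *, fallback: int, hint_override: Optional[int] = None) -> int:
    # Concrete sizes are used as-is; symbolic sizes take the override when the
    # caller is benchmarking a particular size bucket, else the configured fallback.
    if isinstance(size, int):
        return size
    if hint_override is not None:
        return hint_override
    return fallback

def size_hints(sizes: Sequence, *, fallback: int,
               hint_override: Optional[int] = None) -> tuple:
    return tuple(size_hint(s, fallback=fallback, hint_override=hint_override)
                 for s in sizes)

# e.g. a dynamic dim ("s0" stands in for a symbolic size) benchmarked at the 4096 bucket:
print(size_hints([64, "s0"], fallback=8192, hint_override=4096))  # (64, 4096)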

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 5 additions & 2 deletions
@@ -19,7 +19,7 @@
     SequentialComboKernelGrid,
 )
 from ..scheduler import BaseSchedulerNode
-from ..utils import Placeholder, triton_version_uses_attrs_dict
+from ..utils import is_rocm, Placeholder, triton_version_uses_attrs_dict
 from ..virtualized import V
 from .common import (
     ArgName,
@@ -742,10 +742,13 @@ def kernel_benchmark_extra_args(self) -> list[str]:
                 continue
             # pyrefly: ignore [missing-argument]
             if not tree.is_reduction or sub_kernel.inside_reduction:
+                meta_hint = sub_kernel.hint_override if is_rocm() else None
                 extra_args.append(
                     str(
                         V.graph.sizevars.size_hint(
-                            tree.numel, fallback=config.unbacked_symint_fallback
+                            tree.numel,
+                            fallback=config.unbacked_symint_fallback,
+                            hint_override=meta_hint,
                         )
                     )
                 )
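
The effect of the hunk above: when building the extra benchmark arguments for a combo kernel on ROCm, each sub-kernel's numel hint is resolved against the size bucket (hint_override) that sub-kernel was tuned for instead of the generic hint. A small self-contained sketch of that gating (SubKernel and _resolve_numel are hypothetical stand-ins, not inductor classes):

from dataclasses import dataclass
from typing import Optional

def _resolve_numel(numel, *, fallback: int, hint_override: Optional[int]) -> int:
    # stand-in for V.graph.sizevars.size_hint(...): ints pass through, symbolic
    # values resolve to the override (if any) or the fallback
    if isinstance(numel, int):
        return numel
    return hint_override if hint_override is not None else fallback

@dataclass
class SubKernel:                      # hypothetical stand-in for a combo sub-kernel
    numel: object                     # int or a symbolic placeholder
    hint_override: Optional[int] = None

def benchmark_extra_args(sub_kernels, *, on_rocm: bool, fallback: int = 8192) -> list[str]:
    args = []
    for sk in sub_kernels:
        meta_hint = sk.hint_override if on_rocm else None   # ROCm-only gating, as above
        args.append(str(_resolve_numel(sk.numel, fallback=fallback, hint_override=meta_hint)))
    return args

print(benchmark_extra_args([SubKernel("s0", 4096), SubKernel(256)], on_rocm=True))
# ['4096', '256']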

torch/_inductor/select_algorithm.py

Lines changed: 3 additions & 1 deletion
@@ -1495,10 +1495,12 @@ def call_kernel(
             wrapper.generate_workspace_deallocation(self.workspace_arg)
 
     def kernel_benchmark_extra_args(self) -> list[str]:
+        meta_hint = self.hint_override if torch.version.hip else None
         return [
             str(x)
             for x in self.grid_fn(
-                *V.graph.sizevars.size_hints(self.call_sizes), self.meta
+                *V.graph.sizevars.size_hints(self.call_sizes, hint_override=meta_hint),
+                self.meta,
             )
         ]
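
Downstream, those overridden size hints feed the template's grid_fn when the benchmark harness launches the kernel, so a different size bucket produces a different launch grid. A toy illustration only; the grid function below is a generic ceil-division launcher, not the exact one inductor generates:

import math

def grid_fn(m: int, n: int, meta: dict) -> tuple:
    # generic 2D tiling grid: one program per (BLOCK_M x BLOCK_N) tile
    return (math.ceil(m / meta["BLOCK_M"]) * math.ceil(n / meta["BLOCK_N"]), 1, 1)

meta = {"BLOCK_M": 128, "BLOCK_N": 64}
print(grid_fn(8192, 8192, meta))   # grid for the default hint
print(grid_fn(4096, 4096, meta))   # grid for an overridden 4096 bucket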

torch/_inductor/template_heuristics/triton.py

Lines changed: 27 additions & 3 deletions
@@ -646,15 +646,27 @@ def _get_exceeding_shared_memory_checker(
     If the device does not report available shared memory, returns None.
     """
 
+    from ..utils import get_gpu_shared_memory
+
+    sm_available = None
+
     try:
         device = torch.cuda.current_device()
         props = torch.cuda.get_device_properties(device)
         if not hasattr(props, "shared_memory_per_block_optin"):  # for NVidia GPUs
             return None
         sm_available = int(props.shared_memory_per_block_optin)
     except Exception:
-        # If CUDA is not available or properties cannot be queried, return None
-        return None
+        pass
+
+    # ROCm specific logic to get shared memory
+    if torch.version.hip and sm_available is None:
+        try:
+            sm_available = get_gpu_shared_memory()
+            if sm_available == 0:
+                return None
+        except Exception:
+            return None
 
     # TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation.
     def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool:
@@ -1318,6 +1330,7 @@ def _finalize_mm_configs(
                 waves_per_eu,
                 matrix_instr_nonkdim,
                 kpack,
+                conf.hint_override,
             )
 
             # Check if gemm specific arg exists - add to key if does
@@ -1344,7 +1357,12 @@ def _finalize_mm_configs(
             }
             if group_m is not None:
                 kwargs["GROUP_M"] = group_m
-            yield self.triton_config(**kwargs)
+
+            tc = self.triton_config(**kwargs)
+            # Preserve hint_override for multi-kernel support
+            if hasattr(conf, "hint_override") and conf.hint_override is not None:
+                tc.hint_override = conf.hint_override
+            yield tc
 
     def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
         flex_attn_fwd_configs: list[FlexConfig] = []
@@ -1674,6 +1692,12 @@ def _convert_config_to_template_kwargs(
             group_m = triton_config.kwargs.get("GROUP_M", 8)
             options_dict["GROUP_M"] = group_m
 
+        # Keep ROCm multi-kernel size bucket attached to the config
+        if torch.version.hip and "hint_override" not in options_dict:
+            hint_override = getattr(triton_config, "hint_override", None)
+            if hint_override is not None:
+                options_dict["hint_override"] = hint_override
+
         return options_dict
 
     def _get_acc_type(self, dtype: torch.dtype) -> str:
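
The first hunk only changes how the per-block shared-memory limit is obtained (NVIDIA property first, then get_gpu_shared_memory() as the ROCm fallback); the checker that consumes it is unchanged. As a reminder of what that limit gates, here is a rough illustrative version of such a check; the (BLOCK_M + BLOCK_N) * BLOCK_K * dtype_size * num_stages estimate is an approximation for illustration, not quoted from this file:

from typing import Callable, Optional

def make_exceeds_shared_memory_checker(sm_available: Optional[int]) -> Optional[Callable]:
    # If the device limit is unknown, skip the filtering entirely (mirrors the
    # "return None" paths above).
    if not sm_available:
        return None

    def exceeds(block_m: int, block_n: int, block_k: int,
                num_stages: int, dtype_size: int) -> bool:
        # rough estimate: software-pipelined A and B tiles resident per stage
        approx_bytes = (block_m + block_n) * block_k * dtype_size * num_stages
        return approx_bytes > sm_available

    return exceeds

checker = make_exceeds_shared_memory_checker(65536)  # e.g. a 64 KiB reported limit
print(checker(128, 128, 64, 4, 2))   # True: ~128 KiB, config would be filtered out
print(checker(64, 64, 32, 2, 2))     # False: ~16 KiB fits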

torch/_inductor/utils.py

Lines changed: 4 additions & 0 deletions
@@ -3030,6 +3030,10 @@ def is_gpu(device: Optional[str]) -> bool:
     return device in GPU_TYPES
 
 
+def is_rocm() -> bool:
+    return torch.version.hip is not None
+
+
 def device_need_guard(device: str) -> bool:
     return device != "mps" and is_gpu(device)  # TODO: MPS does not expose streams now
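
is_rocm() is a thin convenience wrapper over the torch.version.hip check used throughout the rest of the commit. A trivial usage example (the 4096 bucket value is illustrative only):

from torch._inductor.utils import is_rocm

# Pass a per-bucket size hint only on ROCm, mirroring the combo-kernel change.
hint_override = 4096 if is_rocm() else None
print("ROCm build:", is_rocm(), "hint_override:", hint_override)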
