
Commit 355dc47

fsx950223, xinyazhang, and antiagainst authored
[AMD] Add HIP AOT support to compile.py tool (#7007)
This commit adds HIP AOT compilation support to the `compile.py` tool. It allows compiling Triton kernels into a `.h` and a `.cpp` file that can be integrated into applications. Linking via `link.py` is not yet enabled on HIP and is left as a follow-up task.

Co-authored-by: Xinya Zhang <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
1 parent 09649e2 commit 355dc47
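
As a quick orientation for the new flow, the sketch below drives compile.py as a subprocess, the same way the AOT tests do, and points it at a HIP target via the --target flag introduced by this commit. The kernel file matmul_kernel.py, the kernel name, the signature, and the grid values are hypothetical placeholders, not part of this commit; adjust the tool path to your checkout.

# Hedged sketch (not part of the commit): invoking the AOT compiler for a HIP target.
# "matmul_kernel.py", "matmul", the signature, and the grid below are illustrative only.
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "python/triton/tools/compile.py", "matmul_kernel.py",
        "--kernel-name", "matmul",
        "--signature", "*fp16:16, *fp16:16, *fp16:16, i32, i32, i32",
        "--grid", "1024,1024,1",
        "--target", "hip:gfx942:64",  # new flag; omit it to target the current machine's GPU
        "--out-name", "matmul_fp16",
    ],
    check=True,
)
# On the HIP path this emits matmul_fp16.<hash>_<suffix>.h and .cpp files;
# linking them via link.py is not supported yet (see the commit message).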

File tree: 8 files changed (+200, -22 lines)


.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ jobs:
     pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
     TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
     cd python/test/unit
-    pytest --capture=tee-sys -rfs -n 12 language runtime \
+    pytest --capture=tee-sys -rfs -n 12 language runtime tools \
       --ignore=language/test_line_info.py \
       --ignore=test_debug.py
   # TODO: uncomment

python/test/unit/tools/test_aot.py

Lines changed: 45 additions & 10 deletions
@@ -1,5 +1,7 @@
 import glob
 import os
+import pytest
+import re
 import subprocess
 import sys
 import tempfile
@@ -9,6 +11,7 @@
 import triton
 from triton.backends.compiler import GPUTarget
 from triton.backends.nvidia.driver import include_dirs, library_dirs
+from triton._internal_testing import is_cuda, is_hip

 kernel_utils_src = """
 import triton
@@ -273,6 +276,20 @@ def generate_matmul_test_data(dir, M, N, K):
     return a, b, a_path, b_path, c_path


+def check_hasco_binary_str(tmp_dir: str, dtype: str):
+    # Linking is not yet enabled on HIP backend so just check compilation for now.
+    h_files = glob.glob(f"matmul_{dtype}.*.h", root_dir=tmp_dir)
+    cpp_files = glob.glob(f"matmul_{dtype}.*.cpp", root_dir=tmp_dir)
+    assert len(h_files) == 1, "Expected one .h file"
+    assert len(cpp_files) == 1, "Expected one .cpp file"
+    pattern = re.compile(r'HSACO_NAME\[(\d+)\]')
+    with open(os.path.join(tmp_dir, cpp_files[0]), "r") as cpp_file:
+        content = cpp_file.read()
+    matches = pattern.findall(content)
+    assert len(matches) == 1, "Expected one HSACO_NAME definition"
+    assert int(matches[0]) > 16, "Expected valid HSACO object binary string"
+
+
 # Test edge case where the provided kernel signature has no specializations
 def test_compile_link_matmul_no_specialization():
     np.random.seed(3)
@@ -283,6 +300,10 @@ def test_compile_link_matmul_no_specialization():

         kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
         compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK)
+        if is_hip():
+            check_hasco_binary_str(tmp_dir, dtype)
+            return
+
         link_aot_kernels(tmp_dir)

         # compile test case
@@ -314,6 +335,9 @@ def test_compile_link_matmul():

         kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
         compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":16", ":16")])
+        if is_hip():
+            check_hasco_binary_str(tmp_dir, dtype)
+            return
         link_aot_kernels(tmp_dir)

         # compile test case
@@ -345,6 +369,10 @@ def test_launcher_has_no_available_kernel():

         kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
         compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":1", ":1")])
+        if is_hip():
+            check_hasco_binary_str(tmp_dir, dtype)
+            return
+
         link_aot_kernels(tmp_dir)

         # compile test case
@@ -371,6 +399,7 @@ def test_launcher_has_no_available_kernel():
         assert "kernel launch failed" in result.stderr


+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_compile_link_autotune_matmul():
     np.random.seed(3)

@@ -419,19 +448,25 @@ def test_compile_link_autotune_matmul():
         np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=1e-4)


-def test_ttgir_to_ptx():
+def test_ttgir_to_asm():
     src = """
-module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} {
-  tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
+module attributes {{"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {warp_size} : i32, "ttg.num-ctas" = 1 : i32}} {{
+  tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {{
     tt.return
-  }
-}
+  }}
+}}
 """
+    target = GPUTarget("hip", "gfx942", 64) if is_hip() else GPUTarget("cuda", 80, 32)
     with tempfile.TemporaryDirectory() as tmp_dir:
         kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir")
         with open(kernel_path, "w") as fp:
-            fp.write(src)
-        k = triton.compile(kernel_path, target=GPUTarget("cuda", 80, 32))
-        ptx = k.asm["ptx"]
-        assert ".target sm_80" in ptx
-        assert ".address_size 64" in ptx
+            fp.write(src.format(warp_size=target.warp_size))
+        k = triton.compile(kernel_path, target=target)
+        if is_cuda():
+            ptx = k.asm["ptx"]
+            assert ".target sm_80" in ptx
+            assert ".address_size 64" in ptx
+        elif is_hip():
+            amdgcn = k.asm["amdgcn"]
+            assert '.amdgcn_target "amdgcn-amd-amdhsa--gfx942"' in amdgcn
+            assert '.wavefront_size: 64' in amdgcn

python/triton/backends/driver.py

Lines changed: 13 additions & 0 deletions
@@ -15,6 +15,19 @@ class DriverBase(metaclass=ABCMeta):
     def is_active(self):
         pass

+    @abstractmethod
+    def map_python_to_cpp_type(self, ty: str) -> str:
+        """
+        Converts a Triton type string to its corresponding C++ type string for this backend.
+
+        Args:
+            ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.
+
+        Returns:
+            str: The C++ type string.
+        """
+        pass
+
     @abstractmethod
     def get_current_target(self):
         pass
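
For orientation, here is a minimal sketch of what a backend implementation of this hook might look like. The concrete C++ strings below are assumptions for illustration only; in this commit the AMD and NVIDIA drivers simply forward to their existing ty_to_cpp helpers (see the driver.py changes further down).

# Minimal sketch, assuming illustrative type mappings (not taken from this commit).
class HypotheticalHIPDriver:

    def map_python_to_cpp_type(self, ty: str) -> str:
        # Pointer arguments ("*fp16", "*i32", ...) map to an opaque device pointer type.
        if ty.startswith("*"):
            return "hipDeviceptr_t"  # assumed pointer type for a HIP-flavoured backend
        # Scalar arguments map to fixed-width C++ types; this table is illustrative only.
        scalar_to_cpp = {"i32": "int32_t", "i64": "int64_t", "fp32": "float", "fp64": "double"}
        return scalar_to_cpp[ty]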

python/triton/tools/compile.py

Lines changed: 57 additions & 11 deletions
@@ -3,12 +3,29 @@
 import importlib.util
 import sys
 from argparse import ArgumentParser
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List

 import triton
 import triton.backends
-from triton.backends.nvidia.driver import ty_to_cpp
+
+
+@dataclass
+class CompileArgs:
+    '''
+    A class to contain arguments from command-line parser.
+    '''
+    path: str = ''
+    kernel_name: str = ''
+    signature: str = ''
+    grid: str = ''
+    target: str | None = None
+    num_warps: int = 1
+    num_stages: int = 3
+    out_name: str | None = None
+    out_path: Path | None = None
+

 desc = """
 Triton ahead-of-time compiler:
@@ -36,23 +53,31 @@
 used to run this `compile.py` script
 """

-if __name__ == "__main__":

+def main():
     # command-line arguments
     parser = ArgumentParser(description=desc)
     parser.add_argument("path",
                         help="Path to Python source containing desired kernel in its scope. File will be executed.")
     parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
                         required=True)
+    parser.add_argument(
+        "--target", "-t", type=str, default=None,
+        help="The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
+        "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target")
     parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
     parser.add_argument("--num-stages", "-ns", type=int, default=3,
                         help="Number of stages (meta-parameter of the kernel)")
     parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
     parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
     parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
     parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
-    args = parser.parse_args()
+    cli_args = parser.parse_args()
+    args = CompileArgs(**vars(cli_args))  # A sanity check to ensure class CompileArgs is updated as well.
+    compile_kernel(args)

+
+def compile_kernel(args: CompileArgs):
     out_name = args.out_name if args.out_name else args.kernel_name
     out_path = args.out_path if args.out_path else Path(out_name)

@@ -108,9 +133,15 @@ def constexpr(s):
     assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
     attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
     src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
-    opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
-    ccinfo = triton.compile(src, options=opts)
-    if ccinfo.metadata.global_scratch_size > 0:
+
+    target = triton.backends.compiler.GPUTarget(*args.target.split(":")) \
+        if args.target else triton.runtime.driver.active.get_current_target()
+    backend = triton.compiler.make_backend(target)
+    kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
+    options = backend.parse_options(kwargs)
+    ccinfo = triton.compile(src, target=target, options=options.__dict__)
+
+    if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
         raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")

     arg_names = []
@@ -136,8 +167,12 @@ def constexpr(s):
         if hints.get((i, ), None) == 16:
             suffix += 'd'
     func_name = '_'.join([out_name, sig_hash, suffix])
-    asm = ccinfo.asm["cubin"]  # store binary data once
+    asm = ccinfo.asm[backend.binary_ext]  # store binary data once
+
     hex_ = str(binascii.hexlify(asm))[2:-1]
+
+    ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type
+
     params = {
         "kernel_name": func_name,
         "triton_kernel_name": args.kernel_name,
@@ -156,7 +191,18 @@ def constexpr(s):
         "gridZ": grid[2],
         "_placeholder": "",
     }
-    for ext in ['h', 'c']:
-        template_path = Path(__file__).parent / "extra" / "cuda" / f"compile.{ext}"
-        with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp:
-            fp.write(Path(template_path).read_text().format(**params))
+    output_files = []
+    backend_name = target.backend
+    template_dir = Path(__file__).parent / "extra" / backend_name
+    for template_path in template_dir.glob('compile.*'):
+        ext = template_path.suffix
+        output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
+        with output_file.open("w") as fp:
+            fp.write(template_path.read_text().format(**params))
+        output_files.append(output_file)
+
+    return func_name, output_files
+
+
+if __name__ == "__main__":
+    main()
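
Because the command-line handling is now factored into CompileArgs and compile_kernel(), the tool can also be driven programmatically. A hedged sketch follows, assuming a hypothetical add_kernel.py that defines a kernel named add_kernel; the signature, grid, and target values are illustrative.

# Hedged sketch (not part of the commit): calling the refactored entry point directly.
# "add_kernel.py", the kernel name, signature, and grid values are hypothetical.
from triton.tools.compile import CompileArgs, compile_kernel

args = CompileArgs(
    path="add_kernel.py",
    kernel_name="add_kernel",
    signature="*fp32:16, *fp32:16, *fp32:16, i32",
    grid="1024,1,1",
    target="hip:gfx942:64",  # or None to compile for the current machine's GPU
    num_warps=4,
)
func_name, output_files = compile_kernel(args)
# On the HIP backend, output_files holds the generated .h and .cpp sources; they still
# have to be built into the application by hand, since link.py support for HIP is a follow-up.
print(func_name, [str(f) for f in output_files])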

third_party/amd/backend/driver.py

Lines changed: 3 additions & 0 deletions
@@ -585,6 +585,9 @@ def is_active():
         except ImportError:
             return False

+    def map_python_to_cpp_type(self, ty: str) -> str:
+        return ty_to_cpp(ty)
+
     def get_current_target(self):
         device = self.get_current_device()
         device_properties = self.utils.get_device_properties(device)

third_party/amd/tools/hip/compile.cpp

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+/* clang-format off */
+#include <stdio.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <string.h>
+#include <hip/hip_runtime.h>
+
+// helpers to check for hip errors
+#define HIP_CHECK(ans) {{\
+    gpuAssert((ans), __FILE__, __LINE__);\
+  }}\
+
+static inline void gpuAssert(hipError_t code, const char *file, int line) {{
+  if (code != hipSuccess) {{
+    const char *prefix = "Triton Error [HIP]: ";
+    const char *str;
+    hipDrvGetErrorString(code, &str);
+    char err[1024] = {{0}};
+    strcat(err, prefix);
+    strcat(err, str);
+    printf("%s\\n", err);
+    exit(code);
+  }}
+}}
+
+// globals
+#define HSACO_NAME {kernel_name}_hsaco
+hipModule_t {kernel_name}_mod = nullptr;
+hipFunction_t {kernel_name}_func = nullptr;
+unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};
+
+
+void unload_{kernel_name}(void) {{
+    HIP_CHECK(hipModuleUnload({kernel_name}_mod));
+}}
+
+
+void load_{kernel_name}() {{
+    int dev = 0;
+    void *bin = (void *)&HSACO_NAME;
+    int shared = {shared};
+    HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin));
+    HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
+}}
+
+/*
+{kernel_docstring}
+*/
+hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
+    if ({kernel_name}_func == nullptr)
+        load_{kernel_name}();
+    unsigned int gX = {gridX};
+    unsigned int gY = {gridY};
+    unsigned int gZ = {gridZ};
+    hipDeviceptr_t global_scratch = 0;
+    void *args[{num_args}] = {{ {arg_pointers} }};
+    // TODO: shared memory
+    if(gX * gY * gZ > 0)
+        return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr);
+    else
+        return hipErrorInvalidValue;
+}}

third_party/amd/tools/hip/compile.h

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <hip/hip_runtime.h>
+#include <inttypes.h>
+#include <stdint.h>
+#include <stdio.h>
+
+void unload_{kernel_name}(void);
+void load_{kernel_name}(void);
+hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature});

third_party/nvidia/backend/driver.py

Lines changed: 3 additions & 0 deletions
@@ -683,6 +683,9 @@ def is_active():
         except ImportError:
             return False

+    def map_python_to_cpp_type(self, ty: str) -> str:
+        return ty_to_cpp(ty)
+
     def get_benchmarker(self):
         from triton.testing import do_bench
         return do_bench
