Skip to content

Commit 91302ea

Browse files
authored
[AMD][BACKEND] Switch to code object v5 (#5005)
Switches to code object v5, which requires bumping `rocm` to `6.2+` to avoid segfaults with `device_print` and `tl.num_programs`. The added unit test covers the previous segfault that occurred when calling `device_print` on 2D tensors. This is preparation for testing llvm/llvm-project@c4d8920
1 parent 2f40358 commit 91302ea

File tree

5 files changed

+33
-12
lines changed

5 files changed

+33
-12
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,5 +253,5 @@ Supported Platforms:
253253
Supported Hardware:
254254

255255
- NVIDIA GPUs (Compute Capability 8.0+)
256-
- AMD GPUs (ROCm 5.2+)
256+
- AMD GPUs (ROCm 6.2+)
257257
- Under development: CPUs

python/test/unit/language/print_helper.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,17 @@ def kernel_print_pointer(X, Y, BLOCK: tl.constexpr):
9090
tl.device_print("ptr ", X + tl.arange(0, BLOCK))
9191

9292

93+
@triton.jit
94+
def kernel_print_2d_tensor(X, Y, BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr):  # Loads a BLOCK_SIZE_X x BLOCK_SIZE_Y tile from X and device-prints it; Y is not used in this kernel.
95+
off_x = tl.arange(0, BLOCK_SIZE_X)  # indices along the first (row) dimension
96+
off_y = tl.arange(0, BLOCK_SIZE_Y)  # indices along the second (column) dimension
97+
x = tl.load(X + off_x[:, None] * BLOCK_SIZE_Y + off_y[None, :])  # broadcast to a 2D offset grid with a row stride of BLOCK_SIZE_Y elements
98+
tl.device_print("", x)  # prints the loaded 2D tile with an empty prefix string
99+
100+
93101
def test_print(func: str, data_type: str, device: str):
94102
N = 128 # This value should match with test_print in test_subprocess.py.
95-
# TODO(antiagainst): Currently the warp count is chosen to make sure wedon't have multiple
103+
# TODO(antiagainst): Currently the warp count is chosen to make sure we don't have multiple
96104
# threads printing duplicated messages due to broadcasting. Improve print op lowering logic
97105
# to filter out duplicated data range.
98106
num_warps = N // get_current_target_warp_size()
@@ -128,12 +136,18 @@ def test_print(func: str, data_type: str, device: str):
128136
kernel_device_print_hex[(1, )](x, y, num_warps=num_warps, BLOCK=N)
129137
elif func == "device_print_pointer":
130138
kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N)
139+
elif func == "device_print_2d_tensor":
140+
BLOCK_SIZE_X = num_warps
141+
BLOCK_SIZE_Y = get_current_target_warp_size()
142+
x_2d_tensor = x.reshape((BLOCK_SIZE_X, BLOCK_SIZE_Y))
143+
kernel_print_2d_tensor[(1, )](x_2d_tensor, y, num_warps=num_warps, BLOCK_SIZE_X=BLOCK_SIZE_X,
144+
BLOCK_SIZE_Y=BLOCK_SIZE_Y)
131145
else:
132146
assert f"Unknown kernel: {func}"
133147

134148
if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
135149
func != "print_multiple_args" and func != "device_print_multiple_args" and \
136-
func != "device_print_pointer" and func != "device_print_scalar":
150+
func != "device_print_pointer" and func != "device_print_scalar" and func != "device_print_2d_tensor":
137151
assert_close(y, x)
138152

139153
# Wait until driver complete all the jobs for the device_print, especially test_subprocess

python/test/unit/language/test_core.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2525,21 +2525,18 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
25252525
def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device):
25262526

25272527
@triton.jit
2528-
def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.constexpr):
2528+
def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
25292529
start_m = tl.program_id(0)
25302530
pid_n = tl.program_id(1)
2531+
num_pid_n = tl.num_programs(1)
25312532
local = INITIALIZE_PATCH
25322533
off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
2533-
for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), NUM_PID_N):
2534+
for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), num_pid_n):
25342535
off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
25352536
Xs = X + off_m[:, None] * N + off_n[None, :]
25362537
x = tl.load(Xs)
25372538
local = ACCUMULATE_PATCH
2538-
tl.store(Y + off_m * NUM_PID_N + pid_n, local)
2539-
# the following segfaults AMD backend following #3492
2540-
# really unclear why; the llvm-ir and kernel arguments are
2541-
# identical !
2542-
# tl.store(Y + off_m * tl.num_programs(1) + pid_n, local)
2539+
tl.store(Y + off_m * num_pid_n + pid_n, local)
25432540

25442541
initialize_patch = {
25452542
'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)',
@@ -2561,7 +2558,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
25612558
BLOCK_M = 32
25622559
x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device)
25632560
y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device)
2564-
h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n)
2561+
h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N)
25652562
if not is_interpreter():
25662563
assert h.asm['ttgir'].count(
25672564
'"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work"

python/test/unit/language/test_subprocess.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import sys
55
from collections import Counter
66

7+
import triton
8+
79
import pytest
810

911
dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -35,6 +37,7 @@ def is_interpreter():
3537
("device_print_pointer", "int32"),
3638
("device_print_negative", "int32"),
3739
("device_print_uint", "uint32"),
40+
("device_print_2d_tensor", "int32"),
3841
])
3942
def test_print(func_type: str, data_type: str, device: str):
4043
proc = subprocess.run(
@@ -101,6 +104,13 @@ def test_print(func_type: str, data_type: str, device: str):
101104
elif func_type == "device_print_pointer":
102105
for i in range(N):
103106
expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1
107+
elif func_type == "device_print_2d_tensor":
108+
warp_size = triton.runtime.driver.active.get_current_target().warp_size
109+
x_dim = N // warp_size
110+
y_dim = warp_size
111+
for x in range(x_dim):
112+
for y in range(y_dim):
113+
expected_lines[f"pid (0, 0, 0) idx ({x}, {y:2}): {(x * y_dim + y)}"] = 1
104114

105115
actual_lines = Counter()
106116
for line in outs:

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def make_llir(src, metadata, options):
321321
# Set various control constants on the LLVM module so that device
322322
# libraries can resolve references to them.
323323
amd.set_isa_version(llvm_mod, options.arch)
324-
amd.set_abi_version(llvm_mod, 400)
324+
amd.set_abi_version(llvm_mod, 500)
325325
amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
326326
amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
327327
amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)

0 commit comments

Comments
 (0)