Skip to content

Commit 67af519

Browse files
authored
[AMD][Test] Enable test_override_arch (#7691)
1 parent 0ec07e7 commit 67af519

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

python/test/unit/language/test_core.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7134,11 +7134,13 @@ def mul_add(data):
71347134
# -----------------------
71357135

71367136

7137-
@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90"])
7137+
@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90", "gfx942", "gfx950", "gfx1200"])
71387138
@pytest.mark.parametrize("env_var_override", [False, True])
71397139
def test_override_arch(arch, env_var_override, device):
7140-
if not is_cuda():
7141-
pytest.skip('arch only for CUDA')
7140+
if arch.startswith("sm") and not is_cuda():
7141+
pytest.skip(f"{arch} arch only for CUDA")
7142+
elif arch.startswith("gfx") and not is_hip():
7143+
pytest.skip(f"{arch} arch only for HIP")
71427144

71437145
@triton.jit
71447146
def simple(data, out):
@@ -7149,15 +7151,31 @@ def simple(data, out):
71497151
data = torch.randn((128, ), device=device, dtype=torch.float32)
71507152
out = torch.empty_like(data)
71517153

7152-
if env_var_override:
7153-
os.environ["TRITON_OVERRIDE_ARCH"] = str(arch)
7154-
h = simple[(1, )](data, out)
7155-
os.environ.pop("TRITON_OVERRIDE_ARCH")
7156-
else:
7157-
h = simple[(1, )](data, out, arch=arch)
7158-
torch.testing.assert_close(data * 1.5 + 1.0, out)
7159-
ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"])
7160-
assert ttgir_cc.group(1) == arch[2:]
7154+
if is_cuda():
7155+
if env_var_override:
7156+
os.environ["TRITON_OVERRIDE_ARCH"] = str(arch)
7157+
h = simple[(1, )](data, out)
7158+
os.environ.pop("TRITON_OVERRIDE_ARCH")
7159+
else:
7160+
h = simple[(1, )](data, out, arch=arch)
7161+
torch.testing.assert_close(data * 1.5 + 1.0, out)
7162+
ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"])
7163+
assert ttgir_cc.group(1) == arch[2:]
7164+
elif is_hip():
7165+
# For HIP, the generated kernel is a binary containing the final ISA. So we cannot run
7166+
# them like CUDA side if the chip doesn't match. Here we just check generated ISA.
7167+
if env_var_override:
7168+
os.environ["TRITON_OVERRIDE_ARCH"] = str(arch)
7169+
h = simple.warmup(data, out, grid=(1, ))
7170+
os.environ.pop("TRITON_OVERRIDE_ARCH")
7171+
else:
7172+
h = simple.warmup(data, out, arch=arch, grid=(1, ))
7173+
ttgir_gfx = re.search(r'hip:(\w+)', h.asm["ttgir"])
7174+
ttgir_warp = re.search(r'"ttg.threads-per-warp" = (\d+)', h.asm["ttgir"])
7175+
amdgcn_gfx = re.search(r'.amdgcn_target "amdgcn-amd-amdhsa--(\w+)"', h.asm["amdgcn"])
7176+
assert ttgir_gfx.group(1) == arch
7177+
assert int(ttgir_warp.group(1)) == (32 if arch == "gfx1200" else 64)
7178+
assert amdgcn_gfx.group(1) == arch
71617179

71627180

71637181
# -----------------------

0 commit comments

Comments
 (0)