Skip to content

Commit 1495116

Browse files
authored
[AMD] Add missing i16 for wmma and disable some tests (#4843)
1 parent 33c0c1c commit 1495116

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

python/test/unit/language/test_core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3330,7 +3330,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
33303330
if is_hip():
33313331
# hip does not support tf32 precision, so use ieee for all tests
33323332
input_precision = "ieee"
3333-
if "gfx11" in triton.runtime.driver.active.get_current_target().arch:
3333+
arch = triton.runtime.driver.active.get_current_target().arch
3334+
if "gfx11" in arch or "gfx12" in arch:
33343335
if in_dtype_str == "float32":
33353336
pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
33363337
if out_dtype_str == "float16":

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ std::string getTypeStr(Type ty) {
162162
scalarName = "bf16";
163163
} else if (ty.isInteger(32)) {
164164
scalarName = "i32";
165+
} else if (ty.isInteger(16)) {
166+
scalarName = "i16";
165167
} else if (ty.isInteger(8)) {
166168
scalarName = "iu8";
167169
} else if (ty.isInteger(4)) {

0 commit comments

Comments
 (0)