Skip to content

Commit 055cd74

Browse files
committed
Add RDNA3/RDNA4 tensor core support
Add MMA (Matrix Multiply-Accumulate) tensor core support for AMD RDNA3 and RDNA4 GPUs by splitting the AMD GPU handling in `get_mma_shape` into separate RDNA and CDNA code paths, each returning architecture-appropriate MMA shapes. RDNA1 and RDNA2 are explicitly rejected with a compile-time constraint: they have limited tensor core capabilities and would require fallback paths that are not yet implemented. With this change, RDNA3 GPUs (RX 7000 series, W7900) and RDNA4 GPUs can use their tensor cores for matrix operations.
1 parent 218d502 commit 055cd74

File tree

1 file changed

+44
-8
lines changed

1 file changed

+44
-8
lines changed

max/kernels/src/layout/tensor_core.mojo

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ from sys import (
5454
size_of,
5555
)
5656

57+
from sys.info import (
58+
_is_amd_rdna,
59+
_is_amd_rdna1,
60+
_is_amd_rdna2,
61+
_is_amd_rdna3,
62+
_is_amd_rdna4,
63+
_is_amd_cdna,
64+
)
65+
66+
5767
from gpu import WARP_SIZE, lane_id, thread_idx
5868
from gpu.intrinsics import lop
5969
from gpu.memory import AddressSpace
@@ -1391,15 +1401,41 @@ fn get_mma_shape[
13911401
else:
13921402

13931403
@parameter
1394-
if accum_type is DType.float32 and input_type is DType.float32:
1395-
return shape_16x16x4
1396-
elif accum_type is DType.float32 and input_type.is_half_float():
1397-
return shape_16x16x16
1398-
elif accum_type is DType.float32 and input_type.is_float8():
1399-
return shape_16x16x32
1404+
if _is_amd_rdna():
1405+
1406+
@parameter
1407+
if _is_amd_rdna1() or _is_amd_rdna2():
1408+
constrained[
1409+
False,
1410+
(
1411+
"RDNA1/RDNA2 tensor core support requires fallback"
1412+
" paths (not yet implemented)"
1413+
),
1414+
]()
1415+
return shape_null
1416+
1417+
@parameter
1418+
if accum_type is DType.float32 and input_type is DType.float32:
1419+
return shape_16x16x16
1420+
elif accum_type is DType.float32 and input_type.is_half_float():
1421+
return shape_16x16x16
1422+
elif accum_type is DType.float32 and input_type.is_float8():
1423+
return shape_16x16x32
1424+
else:
1425+
constrained[False, "Unsupported RDNA mma shape."]()
1426+
return shape_null
14001427
else:
1401-
constrained[False, "Unsupported mma shape."]()
1402-
return shape_null
1428+
1429+
@parameter
1430+
if accum_type is DType.float32 and input_type is DType.float32:
1431+
return shape_16x16x4
1432+
elif accum_type is DType.float32 and input_type.is_half_float():
1433+
return shape_16x16x16
1434+
elif accum_type is DType.float32 and input_type.is_float8():
1435+
return shape_16x16x32
1436+
else:
1437+
constrained[False, "Unsupported CDNA mma shape."]()
1438+
return shape_null
14031439

14041440

14051441
@always_inline

0 commit comments

Comments
 (0)