
Commit ab86bc0

Add RDNA3/4 BF16/FP16 WMMA support
Add BF16/FP16 WMMA support for RDNA3/4 GPUs using WMMA intrinsics. We handle packed size (8, 8, 32, 32) operations that require splitting into multiple 16x16x16 WMMA calls.

RDNA WMMA intrinsics support:
- Size 4 (single wave): direct 16x16x16 operation
- Size (8, 8, 32, 32) (Wave32 packed): split into two 16x16x16 operations

The size 8 BF16/FP16 case requires splitting:
1. Each WMMA call operates on 4 elements at a time.
2. A @parameter loop splits the work into 2 iterations (offsets 0 and 4).
3. The results are combined with SIMD.insert() for each iteration.

FP8 support requires more work, so we annotate it and prepare a branch for future work.
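For illustration, here is a minimal sketch of that splitting pattern in isolation. It mirrors the (8, 8, 32, 32) branch added in the diff below with the f16 intrinsic name filled in; the helper name is hypothetical and the `from sys import llvm_intrinsic` path is an assumption, not taken from this commit.

```mojo
from sys import llvm_intrinsic  # assumed import path for llvm_intrinsic


@always_inline
fn _wmma_f16_packed_sketch(
    mut d: SIMD[DType.float32, 32],
    a: SIMD[DType.float16, 8],
    b: SIMD[DType.float16, 8],
    c: SIMD[DType.float32, 32],
):
    # Hypothetical helper (not in the commit): split one Wave32-packed
    # (8, 8, 32, 32) half-precision WMMA into two 16x16x16 intrinsic calls,
    # following the pattern added in the diff below.
    var result = c.copy()

    @parameter
    for i in range(2):
        alias offset = i * 4
        # Each iteration feeds a 4-element slice of every operand to one
        # llvm.amdgcn.wmma.f32.16x16x16.f16 call.
        var r = llvm_intrinsic[
            "llvm.amdgcn.wmma.f32.16x16x16.f16", SIMD[DType.float32, 4]
        ](
            a.slice[4, offset=offset](),
            b.slice[4, offset=offset](),
            c.slice[4, offset=offset](),
        )
        # Write the 4-wide partial result back at the matching offset.
        result = result.insert[offset=offset](r)

    d = result
```

The actual change derives the intrinsic name from the element type instead of hard-coding f16, so the same structure also covers BF16 via llvm.amdgcn.wmma.f32.16x16x16.bf16.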
1 parent 8222270 commit ab86bc0

1 file changed

mojo/stdlib/stdlib/gpu/mma.mojo

Lines changed: 92 additions & 23 deletions
@@ -21,6 +21,8 @@ from sys.info import (
     CompilationTarget,
     _cdna_4_or_newer,
     _is_amd_rdna,
+    _is_amd_rdna3,
+    _is_amd_rdna4,
     is_amd_gpu,
 )

@@ -90,17 +92,39 @@ fn _has_shape[

 @always_inline
 fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
-    """AMD RDNA3+ WMMA implementation for matrix multiplication.
-
-    RDNA3/4 GPUs use WMMA instructions.
-    Per https://gpuopen.com/learn/wmma_on_rdna3/
-    the following intrinsics are supported:
-    - llvm.amdgcn.wmma.f32.16x16x16.f16
-    - llvm.amdgcn.wmma.f32.16x16x16.bf16
-    - llvm.amdgcn.wmma.f16.16x16x16.f16
-    - llvm.amdgcn.wmma.bf16.16x16x16.bf16
-    - llvm.amdgcn.wmma.i32.16x16x16.iu8
-    - llvm.amdgcn.wmma.i32.16x16x16.iu4
+    """Performs AMD RDNA3+ WMMA (Wave Matrix Multiply-Accumulate) operations.
+
+    This function implements matrix multiply-accumulate operations for AMD RDNA3+
+    consumer GPUs using WMMA instructions. WMMA was introduced in RDNA3 and is not
+    available on RDNA1/2 hardware.
+
+    Supported operations by RDNA generation:
+
+    RDNA3+ (all operations):
+    - F32 = F16 * F16 + F32 (16x16x16 shape)
+    - F32 = BF16 * BF16 + F32 (16x16x16 shape)
+
+    RDNA4 additional operations:
+    - F32 = FP8 * FP8 + F32 (16x16x32 shape, native hardware support)
+
+    Args:
+        d: Output accumulator SIMD vector (modified in-place).
+        a: First input matrix as SIMD vector.
+        b: Second input matrix as SIMD vector.
+        c: Accumulator matrix as SIMD vector.
+
+    Note:
+        RDNA WMMA supports both size 4 (single 16x16x16 operation) and size 8
+        (double wave, split into two 16x16x16 operations). Size 8 for half-precision
+        represents two operations packed together. Type and shape validation
+        is performed by get_intrinsic_name() which calls _unsupported_mma_op()
+        for invalid combinations.
+
+    References:
+    - RDNA3 WMMA: https://gpuopen.com/learn/wmma_on_rdna3/
+    - RDNA3 ISA: https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf
+    - RDNA4 ISA: https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna4-instruction-set-architecture.pdf
+    - Section 7.5 (8-bit Math) for FP8/BF8 details
     """

     @parameter
@@ -111,22 +135,67 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
         # F32 = BF16 * BF16 + F32 (16x16x16)
         # ===------------------------------------------------------------------===#
         @parameter
-        if (
-            _has_type[
-                (DType.float16, DType.float16, DType.float32, DType.float32)
-            ](a.dtype, b.dtype, c.dtype, d.dtype)
-            or _has_type[
-                (DType.bfloat16, DType.bfloat16, DType.float32, DType.float32)
-            ](a.dtype, b.dtype, c.dtype, d.dtype)
-        ) and _has_shape[4](a.size, b.size, c.size, d.size):
-            alias type_name = "f16" if a.dtype is DType.float16 else "bf16"
-            return "llvm.amdgcn.wmma.f32.16x16x16." + type_name
+        if _has_type[
+            (DType.float16, DType.float16, DType.float32, DType.float32)
+        ](a.dtype, b.dtype, c.dtype, d.dtype) or _has_type[
+            (DType.bfloat16, DType.bfloat16, DType.float32, DType.float32)
+        ](
+            a.dtype, b.dtype, c.dtype, d.dtype
+        ):
+
+            @parameter
+            if _has_shape[4](a.size, b.size, c.size, d.size):
+                alias type_name = "f16" if a.dtype is DType.float16 else "bf16"
+                return "llvm.amdgcn.wmma.f32.16x16x16." + type_name
+            elif _has_shape[(8, 8, 32, 32)](a.size, b.size, c.size, d.size):
+                alias type_name = "f16" if a.dtype is DType.float16 else "bf16"
+                return "llvm.amdgcn.wmma.f32.16x16x16." + type_name
+            else:
+                _unsupported_mma_op(d, a, b, c)
+                return ""
+        elif a.dtype in [
+            DType.float8_e4m3fn,
+            DType.float8_e4m3fnuz,
+            DType.float8_e5m2,
+            DType.float8_e5m2fnuz,
+        ] or b.dtype in [
+            DType.float8_e4m3fn,
+            DType.float8_e4m3fnuz,
+            DType.float8_e5m2,
+            DType.float8_e5m2fnuz,
+        ]:
+            _unsupported_mma_op(d, a, b, c)
+            return ""
         else:
             _unsupported_mma_op(d, a, b, c)
             return ""

-    var r = llvm_intrinsic[get_intrinsic_name(), SIMD[c.dtype, c.size]](a, b, c)
-    d = rebind[type_of(d)](r)
+    @parameter
+    if a.size == 8 and b.size == 8 and c.size == 8 and d.size == 8:
+        alias intrinsic_name = get_intrinsic_name()
+        var r = llvm_intrinsic[intrinsic_name, SIMD[c.dtype, 8]](a, b, c)
+        d = rebind[__type_of(d)](r)
+    elif a.size == 8 and b.size == 8 and c.size == 32 and d.size == 32:
+        var result = c.copy()
+        alias intrinsic_name = get_intrinsic_name()
+
+        @parameter
+        for i in range(2):
+            alias offset = i * 4
+            var a_slice = a.slice[4, offset=offset]()
+            var b_slice = b.slice[4, offset=offset]()
+            var c_slice = c.slice[4, offset=offset]()
+            var r = llvm_intrinsic[intrinsic_name, SIMD[c.dtype, 4]](
+                a_slice, b_slice, c_slice
+            )
+            result = result.insert[offset=offset](r)
+
+        d = rebind[__type_of(d)](result)
+    else:
+        var r = llvm_intrinsic[get_intrinsic_name(), SIMD[c.dtype, c.size]](
+            a, b, c
+        )
+        d = rebind[__type_of(d)](r)


 @fieldwise_init
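As a usage-level illustration (not part of the commit), a kernel-side call might look like the sketch below. It assumes the module's public `mma` wrapper dispatches to `_mma_wmma_rdna` on RDNA3/4 targets and that size-4 operands are the per-thread fragment shape for the 16x16x16 case.

```mojo
from gpu.mma import mma  # assumed public wrapper in this module


fn wmma_f16_example():
    # Intended to run inside a GPU kernel compiled for an RDNA3/4 target.
    # Per-thread fragments for F32 = F16 * F16 + F32 at shape 16x16x16,
    # matching the size-4 case accepted above.
    var a = SIMD[DType.float16, 4](1.0)
    var b = SIMD[DType.float16, 4](2.0)
    var c = SIMD[DType.float32, 4](0.0)
    var d = SIMD[DType.float32, 4](0.0)

    # On RDNA3/4 this resolves to a single
    # llvm.amdgcn.wmma.f32.16x16x16.f16 call.
    mma(d, a, b, c)
    _ = d
```

Passing operands with the packed (8, 8, 32, 32) sizes would instead exercise the new two-call splitting path.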
