Commit bbeb3bf

Add RDNA4 native FP8 tensor core support
Add native FP8/BF8 WMMA support for RDNA4 GPUs, which have hardware support for FP8 matrix operations via the llvm.amdgcn.wmma.f32.16x16x32.fp8 intrinsic.

FP8 support by RDNA generation:
- RDNA4: Native FP8/BF8 WMMA (16x16x32 shape)
- RDNA3: Requires emulation via FP16/BF16 conversion (future work)
- RDNA1/2: Not supported (requires fallback paths)

The FP8 intrinsic follows the established RDNA WMMA pattern:

    llvm.amdgcn.wmma.<accum>.<M>x<N>x<K>.<input_type>

For FP8 with FP32 accumulation on RDNA4:

    llvm.amdgcn.wmma.f32.16x16x32.fp8

Supported FP8 formats (per AMD RDNA4 ISA, Section 7.5):
- FP8 E4M3 (float8_e4m3fn): 4 exp bits, 3 mantissa bits, ExpBias=7-8
- FP8 E5M2 (float8_e5m2): 5 exp bits, 2 mantissa bits, ExpBias=15
- BF8 E5M2: 5 exp bits, 2 mantissa bits, ExpBias=16

This enables FP8 quantized model inference on RDNA4 GPUs with native hardware acceleration, while maintaining compile-time guards to prevent unsupported operations on earlier RDNA generations.
1 parent ab86bc0 commit bbeb3bf
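To make the E4M3 layout in the format table above concrete, here is a hypothetical Mojo sketch (not part of this commit) that decodes a normal, finite FP8 E4M3 byte under the ExpBias = 7 interpretation; the function name and the loop-based scaling are illustrative only.

# Illustrative only: decode a normal (non-zero, non-NaN) FP8 E4M3 byte laid
# out as 1 sign bit | 4 exponent bits | 3 mantissa bits, with ExpBias = 7.
fn decode_e4m3_normal(bits: UInt8) -> Float32:
    var sign: Float32 = -1.0 if (bits >> 7) != 0 else 1.0
    var e = Int((bits >> 3) & 0xF) - 7  # remove the exponent bias
    var frac: Float32 = 1.0 + Float32(Int(bits & 0x7)) / 8.0  # implicit leading 1
    # Apply 2^e step by step rather than relying on a pow intrinsic.
    var scale: Float32 = 1.0
    if e >= 0:
        for _ in range(e):
            scale *= 2.0
    else:
        for _ in range(-e):
            scale *= 0.5
    return sign * frac * scale

For example, 0x40 = 0b0_1000_000 decodes to 2^(8 - 7) * 1.0 = 2.0.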

File tree

1 file changed (+63, -13 lines)

mojo/stdlib/stdlib/gpu/mma.mojo

Lines changed: 63 additions & 13 deletions
@@ -43,10 +43,40 @@ from utils.index import Index
 
 
 fn get_amd_fp8_dtype() -> DType:
+    @parameter
+    if _is_amd_rdna():
+
+        @parameter
+        if _is_amd_rdna4():
+            return DType.float8_e4m3fn
+        else:
+            constrained[
+                False,
+                (
+                    "FP8 operations require RDNA4 or newer. RDNA3 and earlier"
+                    " do not support native FP8."
+                ),
+            ]()
+            return DType.float8_e4m3fn
     return DType.float8_e4m3fn if _cdna_4_or_newer() else DType.float8_e4m3fnuz
 
 
 fn get_amd_bf8_dtype() -> DType:
+    @parameter
+    if _is_amd_rdna():
+
+        @parameter
+        if _is_amd_rdna4():
+            return DType.float8_e5m2
+        else:
+            constrained[
+                False,
+                (
+                    "BF8 operations require RDNA4 or newer. RDNA3 and earlier"
+                    " do not support native BF8."
+                ),
+            ]()
+            return DType.float8_e5m2
     return DType.float8_e5m2 if _cdna_4_or_newer() else DType.float8_e5m2fnuz
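A usage sketch (assumed, not part of this diff): since both selectors are compile-time evaluable, a caller can bind the result to an alias, and on RDNA3 and earlier the constrained[] guard above fails the build instead of silently miscompiling. The alias name and fragment width are illustrative.

# Assumed usage, not from this diff: select the FP8 storage type for the
# current AMD target at compile time. Fragment width 8 is illustrative.
alias _amd_fp8 = get_amd_fp8_dtype()
var a_frag = SIMD[_amd_fp8, 8](0)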

@@ -107,6 +137,14 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
     RDNA4 additional operations:
     - F32 = FP8 * FP8 + F32 (16x16x32 shape, native hardware support)
 
+    FP8 support by generation:
+    - RDNA4: Native FP8/BF8 via llvm.amdgcn.wmma.f32.16x16x32.fp8
+      - Supports E4M3 (float8_e4m3fn) and E5M2 (float8_e5m2) formats
+      - Hardware V_DOT4 instructions for 4-element dot products
+      - NEG must be zero for A/B matrices in WMMA operations
+    - RDNA3: FP8 not supported (requires FP16 emulation)
+    - RDNA1/2: FP8 not supported
+
     Args:
         d: Output accumulator SIMD vector (modified in-place).
         a: First input matrix as SIMD vector.
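The support matrix above implies a guard pattern at FP8 entry points. A minimal sketch follows, reusing constrained and _is_amd_rdna4 from this diff; the stub name is hypothetical.

# Minimal sketch: fail compilation on pre-RDNA4 targets rather than emit an
# unsupported WMMA. The stub name is hypothetical.
fn _fp8_wmma_kernel_stub():
    constrained[
        _is_amd_rdna4(),
        "FP8 WMMA requires RDNA4; RDNA3 and earlier need an emulation path.",
    ]()
    # ... issue llvm.amdgcn.wmma.f32.16x16x32.fp8 here ...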
@@ -153,19 +191,31 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
         else:
             _unsupported_mma_op(d, a, b, c)
             return ""
-    elif a.dtype in [
-        DType.float8_e4m3fn,
-        DType.float8_e4m3fnuz,
-        DType.float8_e5m2,
-        DType.float8_e5m2fnuz,
-    ] or b.dtype in [
-        DType.float8_e4m3fn,
-        DType.float8_e4m3fnuz,
-        DType.float8_e5m2,
-        DType.float8_e5m2fnuz,
-    ]:
-        _unsupported_mma_op(d, a, b, c)
-        return ""
+    elif (
+        a.dtype
+        in [
+            DType.float8_e4m3fn,
+            DType.float8_e4m3fnuz,
+            DType.float8_e5m2,
+            DType.float8_e5m2fnuz,
+        ]
+        and b.dtype
+        in [
+            DType.float8_e4m3fn,
+            DType.float8_e4m3fnuz,
+            DType.float8_e5m2,
+            DType.float8_e5m2fnuz,
+        ]
+        and c.dtype is DType.float32
+        and d.dtype is DType.float32
+    ):
+
+        @parameter
+        if _is_amd_rdna4():
+            return "llvm.amdgcn.wmma.f32.16x16x32.fp8"
+        else:
+            _unsupported_mma_op(d, a, b, c)
+            return ""
     else:
         _unsupported_mma_op(d, a, b, c)
         return ""
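Downstream, the returned name would be handed to Mojo's llvm_intrinsic hook. A sketch of a plausible call site follows; the fragment widths and operand packing for the RDNA4 FP8 WMMA are assumptions, not taken from this diff.

from sys.intrinsics import llvm_intrinsic

# Assumed call-site shape: one 16x16x32 FP8 WMMA with the result accumulated
# in FP32. The 8-lane fragment widths are illustrative, not from this diff.
fn _wmma_f32_16x16x32_fp8(
    a: SIMD[DType.float8_e4m3fn, 8],
    b: SIMD[DType.float8_e4m3fn, 8],
    c: SIMD[DType.float32, 8],
) -> SIMD[DType.float32, 8]:
    return llvm_intrinsic[
        "llvm.amdgcn.wmma.f32.16x16x32.fp8",
        SIMD[DType.float32, 8],
    ](a, b, c)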
