Skip to content

Commit abffa82

Browse files
committed
Emulate FP8 on RDNA3 with optimized BF16/FP16 conversion
RDNA3 lacks native FP8 tensor cores. Emulate by converting FP8→FP16/BF16 and splitting size-16 operations into 4x size-4 WMMA calls. This enables FP8-quantized models to run on RDNA3 hardware with a precision fallback for optimal performance. Remove the constraints in get_amd_fp8_dtype() and get_amd_bf8_dtype() that blocked RDNA3 from using FP8 dtypes, now that we have proper emulation.
1 parent bbeb3bf commit abffa82

File tree

1 file changed

+70
-27
lines changed

1 file changed

+70
-27
lines changed

mojo/stdlib/stdlib/gpu/mma.mojo

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,38 +45,14 @@ from utils.index import Index
4545
fn get_amd_fp8_dtype() -> DType:
4646
@parameter
4747
if _is_amd_rdna():
48-
49-
@parameter
50-
if _is_amd_rdna4():
51-
return DType.float8_e4m3fn
52-
else:
53-
constrained[
54-
False,
55-
(
56-
"FP8 operations require RDNA4 or newer. RDNA3 and earlier"
57-
" do not support native FP8."
58-
),
59-
]()
60-
return DType.float8_e4m3fn
48+
return DType.float8_e4m3fn
6149
return DType.float8_e4m3fn if _cdna_4_or_newer() else DType.float8_e4m3fnuz
6250

6351

6452
fn get_amd_bf8_dtype() -> DType:
6553
@parameter
6654
if _is_amd_rdna():
67-
68-
@parameter
69-
if _is_amd_rdna4():
70-
return DType.float8_e5m2
71-
else:
72-
constrained[
73-
False,
74-
(
75-
"BF8 operations require RDNA4 or newer. RDNA3 and earlier"
76-
" do not support native BF8."
77-
),
78-
]()
79-
return DType.float8_e5m2
55+
return DType.float8_e5m2
8056
return DType.float8_e5m2 if _cdna_4_or_newer() else DType.float8_e5m2fnuz
8157

8258

@@ -142,7 +118,16 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
142118
- Supports E4M3 (float8_e4m3fn) and E5M2 (float8_e5m2) formats
143119
- Hardware V_DOT4 instructions for 4-element dot products
144120
- NEG must be zero for A/B matrices in WMMA operations
145-
- RDNA3: FP8 not supported (requires FP16 emulation)
121+
- RDNA3: FP8 is supported through FP16/BF16 emulation:
122+
- FP8 emulation via FP16 or BF16 conversion on RDNA3
123+
- We use BF16 for e5m2 variants (wider exponent range),
124+
FP16 for e4m3 variants
125+
- We support size 16 (split into 4x size 4) and size 4
126+
- For size 8, we need to split into two WMMA operations
127+
- This is a packed format where we do 2x 16x16x16 operations
128+
- But the intrinsic itself only supports size 4
129+
- For size 16 FP8, we split into 4x size 4 operations
130+
- We convert each FP8 chunk to target dtype and run WMMA
146131
- RDNA1/2: FP8 not supported
147132
148133
Args:
@@ -216,10 +201,68 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
216201
else:
217202
_unsupported_mma_op(d, a, b, c)
218203
return ""
204+
elif (
205+
_is_amd_rdna3()
206+
and a.dtype.is_float8()
207+
and b.dtype.is_float8()
208+
and c.dtype is DType.float32
209+
and d.dtype is DType.float32
210+
):
211+
212+
@parameter
213+
if _has_shape[4](a.size, b.size, c.size, d.size):
214+
return "llvm.amdgcn.wmma.f32.16x16x16.f16"
215+
elif (
216+
a.size == 16 and b.size == 16 and c.size == 32 and d.size == 32
217+
):
218+
return "llvm.amdgcn.wmma.f32.16x16x16.f16"
219+
else:
220+
_unsupported_mma_op(d, a, b, c)
221+
return ""
219222
else:
220223
_unsupported_mma_op(d, a, b, c)
221224
return ""
222225

226+
@parameter
227+
if _is_amd_rdna3() and a.dtype.is_float8():
228+
alias target_dtype = DType.bfloat16 if (
229+
a.dtype is DType.float8_e5m2 or a.dtype is DType.float8_e5m2fnuz
230+
) else DType.float16
231+
alias intrinsic_suffix = "bf16" if (
232+
a.dtype is DType.float8_e5m2 or a.dtype is DType.float8_e5m2fnuz
233+
) else "f16"
234+
235+
@parameter
236+
if a.size == 16 and b.size == 16:
237+
var result = c.cast[DType.float32]()
238+
alias intrinsic_name = "llvm.amdgcn.wmma.f32.16x16x16." + intrinsic_suffix
239+
240+
@parameter
241+
for i in range(4):
242+
alias offset = i * 4
243+
var a_chunk = a.slice[4, offset=offset]()
244+
var b_chunk = b.slice[4, offset=offset]()
245+
var c_chunk = result.slice[4, offset=offset]()
246+
var a_converted = a_chunk.cast[target_dtype]()
247+
var b_converted = b_chunk.cast[target_dtype]()
248+
var r_chunk = llvm_intrinsic[
249+
intrinsic_name, SIMD[DType.float32, 4]
250+
](a_converted, b_converted, c_chunk)
251+
result = result.insert[offset=offset](r_chunk)
252+
253+
d = rebind[__type_of(d)](result)
254+
return
255+
elif a.size == 4 and b.size == 4:
256+
var a_converted = a.cast[target_dtype]()
257+
var b_converted = b.cast[target_dtype]()
258+
259+
var r = llvm_intrinsic[
260+
"llvm.amdgcn.wmma.f32.16x16x16." + intrinsic_suffix,
261+
SIMD[c.dtype, c.size],
262+
](a_converted, b_converted, c)
263+
d = rebind[__type_of(d)](r)
264+
return
265+
223266
@parameter
224267
if a.size == 8 and b.size == 8 and c.size == 8 and d.size == 8:
225268
alias intrinsic_name = get_intrinsic_name()

0 commit comments

Comments
 (0)