@@ -45,38 +45,14 @@ from utils.index import Index
4545fn get_amd_fp8_dtype () -> DType:
4646 @parameter
4747 if _is_amd_rdna():
48-
49- @parameter
50- if _is_amd_rdna4():
51- return DType.float8_e4m3fn
52- else :
53- constrained[
54- False ,
55- (
56- " FP8 operations require RDNA4 or newer. RDNA3 and earlier"
57- " do not support native FP8."
58- ),
59- ]()
60- return DType.float8_e4m3fn
48+ return DType.float8_e4m3fn
6149 return DType.float8_e4m3fn if _cdna_4_or_newer() else DType.float8_e4m3fnuz
6250
6351
6452fn get_amd_bf8_dtype () -> DType:
6553 @parameter
6654 if _is_amd_rdna():
67-
68- @parameter
69- if _is_amd_rdna4():
70- return DType.float8_e5m2
71- else :
72- constrained[
73- False ,
74- (
75- " BF8 operations require RDNA4 or newer. RDNA3 and earlier"
76- " do not support native BF8."
77- ),
78- ]()
79- return DType.float8_e5m2
55+ return DType.float8_e5m2
8056 return DType.float8_e5m2 if _cdna_4_or_newer() else DType.float8_e5m2fnuz
8157
8258
@@ -142,7 +118,16 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
142118 - Supports E4M3 (float8_e4m3fn) and E5M2 (float8_e5m2) formats
143119 - Hardware V_DOT4 instructions for 4-element dot products
144120 - NEG must be zero for A/B matrices in WMMA operations
145- - RDNA3: FP8 not supported (requires FP16 emulation)
121+ - RDNA3: no native FP8; FP8 is supported through FP16/BF16 emulation:
122+ - FP8 operands are converted to FP16 or BF16 before the WMMA call
123+ - BF16 is used for e5m2 variants (wider exponent range),
124+ FP16 for e4m3 variants
125+ - Supported operand sizes: 16 (split into 4x size 4) and 4
126+ - Size 8 would need to be split into two WMMA operations
127+ (a packed format where we do 2x 16x16x16 operations);
128+ the intrinsic itself only supports size 4
129+ - Size 16 FP8 is split into 4x size-4 operations
130+ - Each FP8 chunk is converted to the target dtype and run through WMMA
146131 - RDNA1/2: FP8 not supported
147132
148133 Args:
@@ -216,10 +201,68 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
216201 else :
217202 _unsupported_mma_op(d, a, b, c)
218203 return " "
204+ elif (
205+ _is_amd_rdna3()
206+ and a.dtype.is_float8()
207+ and b.dtype.is_float8()
208+ and c.dtype is DType.float32
209+ and d.dtype is DType.float32
210+ ):
211+
212+ @parameter
213+ if _has_shape[4 ](a.size, b.size, c.size, d.size):
214+ return " llvm.amdgcn.wmma.f32.16x16x16.f16"
215+ elif (
216+ a.size == 16 and b.size == 16 and c.size == 32 and d.size == 32
217+ ):
218+ return " llvm.amdgcn.wmma.f32.16x16x16.f16"
219+ else :
220+ _unsupported_mma_op(d, a, b, c)
221+ return " "
219222 else :
220223 _unsupported_mma_op(d, a, b, c)
221224 return " "
222225
226+ @parameter
227+ if _is_amd_rdna3() and a.dtype.is_float8():
228+ alias target_dtype = DType.bfloat16 if (
229+ a.dtype is DType.float8_e5m2 or a.dtype is DType.float8_e5m2fnuz
230+ ) else DType.float16
231+ alias intrinsic_suffix = " bf16" if (
232+ a.dtype is DType.float8_e5m2 or a.dtype is DType.float8_e5m2fnuz
233+ ) else " f16"
234+
235+ @parameter
236+ if a.size == 16 and b.size == 16 :
237+ var result = c.cast[DType.float32]()
238+ alias intrinsic_name = " llvm.amdgcn.wmma.f32.16x16x16." + intrinsic_suffix
239+
240+ @parameter
241+ for i in range (4 ):
242+ alias offset = i * 4
243+ var a_chunk = a.slice[4 , offset=offset]()
244+ var b_chunk = b.slice[4 , offset=offset]()
245+ var c_chunk = result.slice[4 , offset=offset]()
246+ var a_converted = a_chunk.cast[target_dtype]()
247+ var b_converted = b_chunk.cast[target_dtype]()
248+ var r_chunk = llvm_intrinsic[
249+ intrinsic_name, SIMD [DType.float32, 4 ]
250+ ](a_converted, b_converted, c_chunk)
251+ result = result.insert[offset=offset](r_chunk)
252+
253+ d = rebind[__type_of(d)](result)
254+ return
255+ elif a.size == 4 and b.size == 4 :
256+ var a_converted = a.cast[target_dtype]()
257+ var b_converted = b.cast[target_dtype]()
258+
259+ var r = llvm_intrinsic[
260+ " llvm.amdgcn.wmma.f32.16x16x16." + intrinsic_suffix,
261+ SIMD [c.dtype, c.size],
262+ ](a_converted, b_converted, c)
263+ d = rebind[__type_of(d)](r)
264+ return
265+
223266 @parameter
224267 if a.size == 8 and b.size == 8 and c.size == 8 and d.size == 8 :
225268 alias intrinsic_name = get_intrinsic_name()
0 commit comments