@@ -109,6 +109,8 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
109109 RDNA3+ (all operations):
110110 - F32 = F16 * F16 + F32 (16x16x16 shape)
111111 - F32 = BF16 * BF16 + F32 (16x16x16 shape)
112+ - I32 = I8/U8 * I8/U8 + I32 (16x16x16 shape)
113+ - I32 = U4 * U4 + I32 (16x16x16 shape, using DType._uint4)
112114
113115 RDNA4 additional operations:
114116 - F32 = FP8 * FP8 + F32 (16x16x32 shape, native hardware support)
@@ -128,7 +130,21 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
128130 - But the intrinsic itself only supports size 4
129131 - Size 16 FP8 inputs are split into 4x size 4 operations
130132 - We convert each FP8 chunk to target dtype and run WMMA
131- - RDNA1/2: FP8 not supported
133+ - RDNA1/2: NO WMMA support (WMMA introduced in RDNA3)
134+
135+ RDNA3 FP8 emulation details:
136+ - FP8 operations emulated via FP16/BF16 conversion
137+ - BF16 used for E5M2 variants (wider exponent range)
138+ - FP16 used for E4M3 variants (more mantissa precision)
139+ - Size 16 operations split into 4x size 4 WMMA calls
140+ - Size 4 operations converted and executed directly
141+
142+ Hardware intrinsics used:
143+ - llvm.amdgcn.wmma.f32.16x16x16.f16 (FP16)
144+ - llvm.amdgcn.wmma.f32.16x16x16.bf16 (BF16)
145+ - llvm.amdgcn.wmma.i32.16x16x16.iu8 (INT8/UINT8, RDNA3+)
146+ - llvm.amdgcn.wmma.i32.16x16x16.iu4 (UINT4, RDNA3+)
147+ - llvm.amdgcn.wmma.f32.16x16x32.fp8 (FP8, RDNA4 only)
132148
133149 Args:
134150 d: Output accumulator SIMD vector (modified in-place).
@@ -143,6 +159,11 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
143159 is performed by get_intrinsic_name() which calls _unsupported_mma_op()
144160 for invalid combinations.
145161
162+ For quantized int8/uint8 operations, inputs are bitcast to int32 before
163+ passing to the WMMA intrinsic; uint4 inputs are passed through unchanged.
164+
165+ FP8 operations on RDNA4 require NEG=0 for A/B matrices (hardware constraint).
166+
146167 References:
147168 - RDNA3 WMMA: https://gpuopen.com/learn/wmma_on_rdna3/
148169 - RDNA3 ISA: https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf
@@ -219,6 +240,44 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
219240 else :
220241 _unsupported_mma_op(d, a, b, c)
221242 return " "
243+ elif (
244+ (a.dtype is DType.int8 or a.dtype is DType.uint8)
245+ and (b.dtype is DType.int8 or b.dtype is DType.uint8)
246+ and c.dtype is DType.int32
247+ and d.dtype is DType.int32
248+ ):
249+
250+ @parameter
251+ if _is_amd_rdna3() or _is_amd_rdna4():
252+
253+ @parameter
254+ if _has_shape[4 ](a.size, b.size, c.size, d.size):
255+ return " llvm.amdgcn.wmma.i32.16x16x16.iu8"
256+ else :
257+ _unsupported_mma_op(d, a, b, c)
258+ return " "
259+ else :
260+ _unsupported_mma_op(d, a, b, c)
261+ return " "
262+ elif (
263+ a.dtype is DType._uint4
264+ and b.dtype is DType._uint4
265+ and c.dtype is DType.int32
266+ and d.dtype is DType.int32
267+ ):
268+
269+ @parameter
270+ if _is_amd_rdna3() or _is_amd_rdna4():
271+
272+ @parameter
273+ if _has_shape[4 ](a.size, b.size, c.size, d.size):
274+ return " llvm.amdgcn.wmma.i32.16x16x16.iu4"
275+ else :
276+ _unsupported_mma_op(d, a, b, c)
277+ return " "
278+ else :
279+ _unsupported_mma_op(d, a, b, c)
280+ return " "
222281 else :
223282 _unsupported_mma_op(d, a, b, c)
224283 return " "
@@ -285,10 +344,21 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
285344
286345 d = rebind[__type_of(d)](result)
287346 else :
288- var r = llvm_intrinsic[get_intrinsic_name(), SIMD [c.dtype, c.size]](
289- a, b, c
290- )
291- d = rebind[__type_of(d)](r)
347+
348+ @parameter
349+ if (
350+ a.dtype is DType.int8 or a.dtype is DType.uint8
351+ ) and c.dtype is DType.int32:
352+ # Bitcast (reinterpret the bits, no value conversion) the int8/uint8 inputs to int32 for the WMMA intrinsic
353+ var r = llvm_intrinsic[get_intrinsic_name(), SIMD [c.dtype, c.size]](
354+ bitcast[DType.int32, 1 ](a), bitcast[DType.int32, 1 ](b), c
355+ )
356+ d = rebind[__type_of(d)](r)
357+ else :
358+ var r = llvm_intrinsic[get_intrinsic_name(), SIMD [c.dtype, c.size]](
359+ a, b, c
360+ )
361+ d = rebind[__type_of(d)](r)
292362
293363
294364@fieldwise_init
0 commit comments