
Commit e410539

Enable int4/int8 quantized models on RDNA3+
Enable int8/uint8/uint4 WMMA operations on RDNA3+ GPUs for quantized
model inference. RDNA3 introduced WMMA instructions; RDNA1/2 lack WMMA
support.

Supported quantized operations (RDNA3+):
- I32 = I8/U8 * I8/U8 + I32 (16x16x16 shape via iu8 intrinsic)
- I32 = U4 * U4 + I32 (16x16x16 shape via iu4 intrinsic)
1 parent abffa82 commit e410539
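The quantized operations listed above compute a standard mixed-precision matrix-multiply-accumulate. As a rough illustration (not from the commit, and ignoring the hardware register layout), a minimal Python reference of the 16x16x16 I32 = I8 * I8 + I32 semantics:

```python
# Reference semantics for one 16x16x16 integer WMMA tile:
# D (int32) = A (int8) @ B (int8) + C (int32).
# Plain Python model only; the real instruction packs operands into
# 32-bit registers and executes per-wave, which this ignores.
M = N = K = 16

def mma_i8_i32(A, B, C):
    """A: MxK int8, B: KxN int8, C: MxN int32 -> MxN int32."""
    D = [[0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = C[i][j]  # int32 accumulator keeps full precision
            for k in range(K):
                acc += A[i][k] * B[k][j]
            D[i][j] = acc
    return D

A = [[1] * K for _ in range(M)]  # all-ones int8 tile
B = [[2] * N for _ in range(K)]
C = [[5] * N for _ in range(M)]
D = mma_i8_i32(A, B, C)
# each element: 5 + 16 * (1 * 2) = 37
```

The int32 accumulator is what makes low-precision inputs usable: products of 8-bit values are summed without overflow across the K dimension.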

File tree

1 file changed: +75 -5 lines


mojo/stdlib/stdlib/gpu/mma.mojo

Lines changed: 75 additions & 5 deletions
@@ -109,6 +109,8 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
     RDNA3+ (all operations):
     - F32 = F16 * F16 + F32 (16x16x16 shape)
     - F32 = BF16 * BF16 + F32 (16x16x16 shape)
+    - I32 = I8/U8 * I8/U8 + I32 (16x16x16 shape)
+    - I32 = U4 * U4 + I32 (16x16x16 shape, using DType._uint4)
 
     RDNA4 additional operations:
     - F32 = FP8 * FP8 + F32 (16x16x32 shape, native hardware support)
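For the U4 path, eight 4-bit values occupy one 32-bit operand word. A hedged Python sketch of that packing (the element-0-in-low-nibble ordering is an assumption here; the RDNA3 ISA defines the authoritative layout):

```python
def pack_u4(vals):
    """Pack 8 unsigned 4-bit values into one 32-bit word,
    element 0 in the lowest nibble (assumed ordering)."""
    assert len(vals) == 8 and all(0 <= v < 16 for v in vals)
    word = 0
    for i, v in enumerate(vals):
        word |= v << (4 * i)  # each element occupies one nibble
    return word

def unpack_u4(word):
    """Inverse: recover the 8 nibbles from a packed 32-bit word."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]

w = pack_u4([1, 2, 3, 4, 5, 6, 7, 8])
# w == 0x87654321: element 0 lands in the low nibble
```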
@@ -128,7 +130,21 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
     - But the intrinsic itself only supports size 4
     - For size 16 FP8 - will split into 4x size 4 operations
     - We convert each FP8 chunk to target dtype and run WMMA
-    - RDNA1/2: FP8 not supported
+    - RDNA1/2: NO WMMA support (WMMA introduced in RDNA3)
+
+    RDNA3 FP8 emulation details:
+    - FP8 operations emulated via FP16/BF16 conversion
+    - BF16 used for E5M2 variants (wider exponent range)
+    - FP16 used for E4M3 variants (more mantissa precision)
+    - Size 16 operations split into 4x size 4 WMMA calls
+    - Size 4 operations converted and executed directly
+
+    Hardware intrinsics used:
+    - llvm.amdgcn.wmma.f32.16x16x16.f16 (FP16)
+    - llvm.amdgcn.wmma.f32.16x16x16.bf16 (BF16)
+    - llvm.amdgcn.wmma.i32.16x16x16.iu8 (INT8/UINT8, RDNA3+)
+    - llvm.amdgcn.wmma.i32.16x16x16.iu4 (UINT4, RDNA3+)
+    - llvm.amdgcn.wmma.f32.16x16x32.fp8 (FP8, RDNA4 only)
 
     Args:
         d: Output accumulator SIMD vector (modified in-place).
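The E4M3-vs-E5M2 trade-off described in the emulation notes follows from the bit layouts. A small Python decoder (assuming the OCP FP8 layouts: E4M3 with bias 7, E5M2 with bias 15) makes the range difference concrete:

```python
def decode_fp8(bits, exp_bits, mant_bits, bias):
    """Decode one FP8 byte (sign | exponent | mantissa) to a float.
    Handles normals and subnormals; NaN/Inf encodings are ignored."""
    sign = -1.0 if (bits >> 7) & 1 else 1.0
    exp = (bits >> mant_bits) & ((1 << exp_bits) - 1)
    mant = bits & ((1 << mant_bits) - 1)
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (mant / (1 << mant_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + mant / (1 << mant_bits)) * 2.0 ** (exp - bias)

e4m3 = lambda b: decode_fp8(b, 4, 3, 7)   # 3 mantissa bits: more precision
e5m2 = lambda b: decode_fp8(b, 5, 2, 15)  # 5 exponent bits: more range

# Largest finite magnitudes show why E5M2 gets the wider-exponent
# BF16 target: E4M3 tops out at 448.0, E5M2 reaches 57344.0.
```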
@@ -143,6 +159,11 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
         is performed by get_intrinsic_name() which calls _unsupported_mma_op()
         for invalid combinations.
 
+    For quantized integer operations (int8/uint8/uint4), inputs are bitcast to
+    int32 before passing to WMMA intrinsics.
+
+    FP8 operations on RDNA4 require NEG=0 for A/B matrices (hardware constraint).
+
     References:
         - RDNA3 WMMA: https://gpuopen.com/learn/wmma_on_rdna3/
         - RDNA3 ISA: https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf
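The bitcast mentioned in the docstring is a pure reinterpretation of bytes, not a value conversion. A Python model using `struct` (little-endian byte order assumed; the function names here are illustrative, not from the source):

```python
import struct

def bitcast_i8x4_to_i32(vals):
    """Reinterpret four int8 lanes as one int32, lane 0 in the low
    byte -- a model of bitcast[DType.int32, 1] on a SIMD[int8, 4]."""
    return struct.unpack("<i", struct.pack("<4b", *vals))[0]

def bitcast_i32_to_i8x4(word):
    """Inverse reinterpretation: one int32 back to four int8 lanes."""
    return list(struct.unpack("<4b", struct.pack("<i", word)))

w = bitcast_i8x4_to_i32([1, 0, 0, 0])
# w == 1: lane 0 occupies the least-significant byte
```

Because no values change, the intrinsic sees exactly the original 8-bit lanes, just addressed as one 32-bit register operand.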
@@ -219,6 +240,44 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
         else:
             _unsupported_mma_op(d, a, b, c)
             return ""
+    elif (
+        (a.dtype is DType.int8 or a.dtype is DType.uint8)
+        and (b.dtype is DType.int8 or b.dtype is DType.uint8)
+        and c.dtype is DType.int32
+        and d.dtype is DType.int32
+    ):
+
+        @parameter
+        if _is_amd_rdna3() or _is_amd_rdna4():
+
+            @parameter
+            if _has_shape[4](a.size, b.size, c.size, d.size):
+                return "llvm.amdgcn.wmma.i32.16x16x16.iu8"
+            else:
+                _unsupported_mma_op(d, a, b, c)
+                return ""
+        else:
+            _unsupported_mma_op(d, a, b, c)
+            return ""
+    elif (
+        a.dtype is DType._uint4
+        and b.dtype is DType._uint4
+        and c.dtype is DType.int32
+        and d.dtype is DType.int32
+    ):
+
+        @parameter
+        if _is_amd_rdna3() or _is_amd_rdna4():
+
+            @parameter
+            if _has_shape[4](a.size, b.size, c.size, d.size):
+                return "llvm.amdgcn.wmma.i32.16x16x16.iu4"
+            else:
+                _unsupported_mma_op(d, a, b, c)
+                return ""
+        else:
+            _unsupported_mma_op(d, a, b, c)
+            return ""
     else:
         _unsupported_mma_op(d, a, b, c)
         return ""
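The new elif arms extend a dtype-and-shape dispatch. A simplified Python model of that selection logic (names mirror the diff; `rdna3_plus` and `size` are stand-ins for the `_is_amd_rdna3()/_is_amd_rdna4()` and `_has_shape[4]` checks):

```python
def intrinsic_name(a_dtype, b_dtype, c_dtype, size, rdna3_plus):
    """Model of get_intrinsic_name()'s new integer arms: return the
    LLVM intrinsic for a dtype/shape combo, or None if unsupported."""
    int8s = {"int8", "uint8"}
    if a_dtype in int8s and b_dtype in int8s and c_dtype == "int32":
        # int8/uint8 inputs may mix freely; accumulator must be int32
        if rdna3_plus and size == 4:
            return "llvm.amdgcn.wmma.i32.16x16x16.iu8"
        return None  # wrong shape or pre-RDNA3 -> unsupported
    if a_dtype == b_dtype == "_uint4" and c_dtype == "int32":
        if rdna3_plus and size == 4:
            return "llvm.amdgcn.wmma.i32.16x16x16.iu4"
        return None
    return None
```

In the real code the unsupported path calls `_unsupported_mma_op()` rather than returning `None`; this sketch only captures which combinations reach which intrinsic.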
@@ -285,10 +344,21 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
 
         d = rebind[__type_of(d)](result)
     else:
-        var r = llvm_intrinsic[get_intrinsic_name(), SIMD[c.dtype, c.size]](
-            a, b, c
-        )
-        d = rebind[__type_of(d)](r)
+
+        @parameter
+        if (
+            a.dtype is DType.int8 or a.dtype is DType.uint8
+        ) and c.dtype is DType.int32:
+            # Cast inputs to int32 for WMMA intrinsic
+            var r = llvm_intrinsic[get_intrinsic_name(), SIMD[c.dtype, c.size]](
+                bitcast[DType.int32, 1](a), bitcast[DType.int32, 1](b), c
+            )
+            d = rebind[__type_of(d)](r)
+        else:
+            var r = llvm_intrinsic[get_intrinsic_name(), SIMD[c.dtype, c.size]](
+                a, b, c
+            )
+            d = rebind[__type_of(d)](r)
 
 
 @fieldwise_init
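Elsewhere in this function, size-16 FP8 operands are split into four size-4 WMMA calls. A Python sketch of that chunking pattern (the per-chunk operation below is a placeholder dot-product, standing in for one converted size-4 WMMA call):

```python
def split_into_chunks(vec, chunk=4):
    """Split a size-16 operand into four size-4 chunks, mirroring how
    the FP8 emulation path issues 4x size-4 WMMA calls."""
    assert len(vec) % chunk == 0
    return [vec[i:i + chunk] for i in range(0, len(vec), chunk)]

def emulated_op(a16, b16, acc):
    # Each chunk pair contributes independently to the accumulator,
    # just as each size-4 WMMA call accumulates into the same tile.
    for a4, b4 in zip(split_into_chunks(a16), split_into_chunks(b16)):
        acc += sum(x * y for x, y in zip(a4, b4))
    return acc

total = emulated_op(list(range(16)), [1.0] * 16, 0.0)
# sum of 0..15 = 120
```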

0 commit comments
