Commit 558fe27

Add RDNA3+ quantized integer tensor core support
Add tensor core shape support for integer quantization on RDNA3+ GPUs. RDNA3 introduced WMMA instructions that enable efficient quantized model inference. This adds kernel-layer support for the quantized operations enabled by previous stdlib commits.

Supported quantized operations (RDNA3+):
- INT8/UINT8 with INT32 accumulation (16x16x16 shape)
- UINT4 with INT32 accumulation (16x16x16 shape)

Implementation adds shape definitions in get_mma_shape() to route quantized dtypes to correct WMMA intrinsics.

No changes to FP8 paths - FP8 support for RDNA4+ will be added separately once proper loading code exists.
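As a rough illustration of what one "INT8 with INT32 accumulation (16x16x16 shape)" operation computes, here is a minimal Python sketch of a single tile: D = A x B + C with 8-bit inputs and a 32-bit accumulator. It models the arithmetic only. The function name `mma_int8_int32` is hypothetical, and nothing here reflects how the Mojo kernels or the WMMA instruction lay out data across lanes.

```python
# Illustrative model only; not the Mojo kernel code or the hardware intrinsic.
M, N, K = 16, 16, 16  # the 16x16x16 shape named in the commit message

INT8_MIN, INT8_MAX = -128, 127
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def mma_int8_int32(a, b, c):
    """Return d = a @ b + c for one tile: a is MxK, b is KxN, c and d are MxN."""
    d = [[0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = c[i][j]  # accumulator starts from the int32 input tile
            for k in range(K):
                x, y = a[i][k], b[k][j]
                # Inputs must fit in int8 for this model.
                assert INT8_MIN <= x <= INT8_MAX and INT8_MIN <= y <= INT8_MAX
                acc += x * y
            # The running sum must fit in int32 for this model.
            assert INT32_MIN <= acc <= INT32_MAX
            d[i][j] = acc
    return d


# Worst-case magnitude for a single tile: every input is -128.
a = [[-128] * K for _ in range(M)]
b = [[-128] * N for _ in range(K)]
c = [[0] * N for _ in range(M)]
d = mma_int8_int32(a, b, c)
print(d[0][0])  # 16 * (-128 * -128) = 262144
```

The worst-case element, 262144, is far outside the int8 and int16 ranges, which is why the accumulator has to be wider than the inputs.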
1 parent 0891815 commit 558fe27

1 file changed: 6 additions, 0 deletions

max/kernels/src/layout/tensor_core.mojo

Lines changed: 6 additions & 0 deletions
@@ -1421,6 +1421,12 @@ fn get_mma_shape[
         return shape_16x16x16
     elif accum_type is DType.float32 and input_type.is_float8():
         return shape_16x16x32
+    elif accum_type is DType.int32 and (
+        input_type is DType.int8 or input_type is DType.uint8
+    ):
+        return shape_16x16x16
+    elif accum_type is DType.int32 and (input_type is DType._uint4):
+        return shape_16x16x16
     else:
         constrained[False, "Unsupported RDNA mma shape."]()
         return shape_null
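
The branches visible in this hunk can be modeled in a few lines of Python to make the routing easier to follow. This is a sketch under stated assumptions, not the Mojo implementation: `get_mma_shape_model` is a hypothetical name, dtypes are plain strings, the `float8` prefix check stands in for `is_float8()`, and the branches above the hunk are not modeled because they are not shown. The real function rejects unsupported combinations at compile time via `constrained`, whereas this model raises at runtime.

```python
# Illustrative Python model of the visible branches only; names are hypothetical.
def get_mma_shape_model(input_type: str, accum_type: str) -> tuple:
    """Map (input dtype, accumulator dtype) to an (M, N, K) shape."""
    if accum_type == "float32" and input_type.startswith("float8"):
        return (16, 16, 32)  # existing FP8 branch, unchanged by this commit
    elif accum_type == "int32" and input_type in ("int8", "uint8"):
        return (16, 16, 16)  # added: INT8/UINT8 with INT32 accumulation
    elif accum_type == "int32" and input_type == "uint4":
        return (16, 16, 16)  # added: UINT4 with INT32 accumulation
    else:
        # Stand-in for the compile-time constraint in the Mojo source.
        raise ValueError("Unsupported RDNA mma shape.")


for dtype in ("int8", "uint8", "uint4"):
    print(dtype, get_mma_shape_model(dtype, "int32"))
```

Each of the three quantized input types resolves to the same 16x16x16 shape.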