@@ -21,6 +21,8 @@ from sys.info import (
2121 CompilationTarget,
2222 _cdna_4_or_newer,
2323 _is_amd_rdna,
24+ _is_amd_rdna3,
25+ _is_amd_rdna4,
2426 is_amd_gpu,
2527)
2628
@@ -90,17 +92,39 @@ fn _has_shape[
9092
9193@always_inline
9294fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
93- """ AMD RDNA3+ WMMA implementation for matrix multiplication.
94-
95- RDNA3/4 GPUs use WMMA instructions.
96- Per https://gpuopen.com/learn/wmma_on_rdna3/
97- the following intrinsics are supported:
98- - llvm.amdgcn.wmma.f32.16x16x16.f16
99- - llvm.amdgcn.wmma.f32.16x16x16.bf16
100- - llvm.amdgcn.wmma.f16.16x16x16.f16
101- - llvm.amdgcn.wmma.bf16.16x16x16.bf16
102- - llvm.amdgcn.wmma.i32.16x16x16.iu8
103- - llvm.amdgcn.wmma.i32.16x16x16.iu4
95+    """Performs AMD RDNA3+ WMMA (Wave Matrix Multiply-Accumulate) operations.
96+
97+ This function implements matrix multiply-accumulate operations for AMD RDNA3+
98+ consumer GPUs using WMMA instructions. WMMA was introduced in RDNA3 and is not
99+ available on RDNA1/2 hardware.
100+
101+ Supported operations by RDNA generation:
102+
103+ RDNA3+ (all operations):
104+ - F32 = F16 * F16 + F32 (16x16x16 shape)
105+ - F32 = BF16 * BF16 + F32 (16x16x16 shape)
106+
107+ RDNA4 additional operations:
108+ - F32 = FP8 * FP8 + F32 (16x16x32 shape, native hardware support)
109+
110+ Args:
111+ d: Output accumulator SIMD vector (modified in-place).
112+ a: First input matrix as SIMD vector.
113+ b: Second input matrix as SIMD vector.
114+ c: Accumulator matrix as SIMD vector.
115+
116+ Note:
117+ RDNA WMMA supports both size 4 (single 16x16x16 operation) and size 8
118+ (double wave, split into two 16x16x16 operations). Size 8 for half-precision
119+ represents two operations packed together. Type and shape validation
120+ is performed by get_intrinsic_name() which calls _unsupported_mma_op()
121+ for invalid combinations.
122+
123+ References:
124+ - RDNA3 WMMA: https://gpuopen.com/learn/wmma_on_rdna3/
125+ - RDNA3 ISA: https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf
126+ - RDNA4 ISA: https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna4-instruction-set-architecture.pdf
127+ - Section 7.5 (8-bit Math) for FP8/BF8 details
104128 """
105129
106130 @parameter
@@ -111,22 +135,67 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
111135 # F32 = BF16 * BF16 + F32 (16x16x16)
112136 # ===------------------------------------------------------------------===#
113137 @parameter
114- if (
115- _has_type[
116- (DType.float16, DType.float16, DType.float32, DType.float32)
117- ](a.dtype, b.dtype, c.dtype, d.dtype)
118- or _has_type[
119- (DType.bfloat16, DType.bfloat16, DType.float32, DType.float32)
120- ](a.dtype, b.dtype, c.dtype, d.dtype)
121- ) and _has_shape[4 ](a.size, b.size, c.size, d.size):
122- alias type_name = " f16" if a.dtype is DType.float16 else " bf16"
123- return " llvm.amdgcn.wmma.f32.16x16x16." + type_name
138+        if _has_type[
139+            (DType.float16, DType.float16, DType.float32, DType.float32)
140+        ](a.dtype, b.dtype, c.dtype, d.dtype) or _has_type[
141+            (DType.bfloat16, DType.bfloat16, DType.float32, DType.float32)
142+        ](
143+            a.dtype, b.dtype, c.dtype, d.dtype
144+        ):
145+
146+            @parameter
147+            if _has_shape[4](a.size, b.size, c.size, d.size):
148+                alias type_name = "f16" if a.dtype is DType.float16 else "bf16"
149+                return "llvm.amdgcn.wmma.f32.16x16x16." + type_name
150+            elif _has_shape[(8, 8, 32, 32)](a.size, b.size, c.size, d.size):
151+                alias type_name = "f16" if a.dtype is DType.float16 else "bf16"
152+                return "llvm.amdgcn.wmma.f32.16x16x16." + type_name
153+            else:
154+                _unsupported_mma_op(d, a, b, c)
155+                return ""
156+        elif a.dtype in [
157+            DType.float8_e4m3fn,
158+            DType.float8_e4m3fnuz,
159+            DType.float8_e5m2,
160+            DType.float8_e5m2fnuz,
161+        ] or b.dtype in [
162+            DType.float8_e4m3fn,
163+            DType.float8_e4m3fnuz,
164+            DType.float8_e5m2,
165+            DType.float8_e5m2fnuz,
166+        ]:
167+            _unsupported_mma_op(d, a, b, c)
168+            return ""
124169 else :
125170 _unsupported_mma_op(d, a, b, c)
126171 return " "
127172
128- var r = llvm_intrinsic[get_intrinsic_name(), SIMD [c.dtype, c.size]](a, b, c)
129- d = rebind[type_of(d)](r)
173+    @parameter
174+    if a.size == 8 and b.size == 8 and c.size == 8 and d.size == 8:
175+        alias intrinsic_name = get_intrinsic_name()
176+        var r = llvm_intrinsic[intrinsic_name, SIMD[c.dtype, 8]](a, b, c)
177+        d = rebind[__type_of(d)](r)
178+    elif a.size == 8 and b.size == 8 and c.size == 32 and d.size == 32:
179+        var result = c.copy()
180+        alias intrinsic_name = get_intrinsic_name()
181+
182+        @parameter
183+        for i in range(2):
184+            alias offset = i * 4
185+            var a_slice = a.slice[4, offset=offset]()
186+            var b_slice = b.slice[4, offset=offset]()
187+            var c_slice = c.slice[4, offset=offset]()
188+            var r = llvm_intrinsic[intrinsic_name, SIMD[c.dtype, 4]](
189+                a_slice, b_slice, c_slice
190+            )
191+            result = result.insert[offset=offset](r)
192+
193+        d = rebind[__type_of(d)](result)
194+    else:
195+        var r = llvm_intrinsic[get_intrinsic_name(), SIMD[c.dtype, c.size]](
196+            a, b, c
197+        )
198+        d = rebind[__type_of(d)](r)
130199
131200
132201@fieldwise_init
0 commit comments