Skip to content

Commit 0891815

Browse files
committed
Merge duplicate RDNA WMMA intrinsic branches
Merge three identical branches that all returned the same intrinsic "llvm.amdgcn.wmma.f32.16x16x16.{f16,bf16}" for different SIMD sizes. All three shapes (4, (8,8,8,8), and (8,8,32,32)) use the same 16x16x16 intrinsic, so consolidate into a single branch with combined comments explaining what each size represents.
1 parent e410539 commit 0891815

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

mojo/stdlib/stdlib/gpu/mma.mojo

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
115115
RDNA4 additional operations:
116116
- F32 = FP8 * FP8 + F32 (16x16x32 shape, native hardware support)
117117
118+
RDNA WMMA supports multiple sizes:
119+
- Size 4: single wave
120+
- Size (8,8,8,8): Wave32 mode with 8 accumulator registers
121+
- Size (8,8,32,32): packed operations (split into multiple WMMA ops)
122+
- All use the same 16x16x16 intrinsic
123+
118124
FP8 support by generation:
119125
- RDNA4: Native FP8/BF8 via llvm.amdgcn.wmma.f32.16x16x32.fp8
120126
- Supports E4M3 (float8_e4m3fn) and E5M2 (float8_e5m2) formats
@@ -188,10 +194,11 @@ fn _mma_wmma_rdna(mut d: SIMD, a: SIMD, b: SIMD, c: SIMD):
188194
):
189195

190196
@parameter
191-
if _has_shape[4](a.size, b.size, c.size, d.size):
192-
alias type_name = "f16" if a.dtype is DType.float16 else "bf16"
193-
return "llvm.amdgcn.wmma.f32.16x16x16." + type_name
194-
elif _has_shape[(8, 8, 32, 32)](a.size, b.size, c.size, d.size):
197+
if (
198+
_has_shape[4](a.size, b.size, c.size, d.size)
199+
or _has_shape[(8, 8, 8, 8)](a.size, b.size, c.size, d.size)
200+
or _has_shape[(8, 8, 32, 32)](a.size, b.size, c.size, d.size)
201+
):
195202
alias type_name = "f16" if a.dtype is DType.float16 else "bf16"
196203
return "llvm.amdgcn.wmma.f32.16x16x16." + type_name
197204
else:

0 commit comments

Comments
 (0)