Skip to content

Commit 1e36cbe

Browse files
bchetiouiGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Raise a NotImplementedError if swizzle=16.
Unswizzled MMAs don't lower correctly, and are not currently intended to be supported. PiperOrigin-RevId: 737981373
1 parent 8da9324 commit 1e36cbe

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

jax/experimental/mosaic/gpu/tcgen05.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def mma(
8383
accumulate: ir.Value | bool = True,
8484
collective: bool = False,
8585
):
86+
if a_swizzle == 16 or b_swizzle == 16:
87+
raise NotImplementedError("No swizzle is not supported")
8688
i32 = ir.IntegerType.get_signless(32)
8789
i64 = ir.IntegerType.get_signless(64)
8890
if isinstance(accumulate, bool):

jax/experimental/mosaic/gpu/wgmma.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ def wgmma(
259259
The refs must be contiguous or be contiguous except for having their two minor
260260
dimensions swapped.
261261
"""
262+
if swizzle == 16:
263+
raise NotImplementedError("No swizzle is not supported")
262264
# Step 1. Establish the shape and element type of the operation.
263265
if not ir.MemRefType.isinstance(b.type):
264266
raise ValueError(f"B must be a memref, got: {b.type}")

0 commit comments

Comments
 (0)