
Commit d4bd257

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts
PiperOrigin-RevId: 737956598
1 parent 38d52a1 commit d4bd257

File tree: 2 files changed (+190 −77 lines)

jax/experimental/mosaic/gpu/fragmented_array.py
tests/mosaic/gpu_test.py


jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 138 additions & 50 deletions
@@ -526,6 +526,18 @@ def linear_thread_idxs(self):
     lane_dims=(-4, -2, -3),
     vector_dim=-1,
 )
+# This layout should be used when upcasting 4-bit elements to 16-bit, for the
+# purpose of passing them into WGMMA later. The core matrices stored by a warp
+# are 8x32, because each of the 4 threads in a row holds 8 elements in a single
+# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each
+# group of 4 threads in order (as opposed to the swapping between 1 and 2,
+# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does).
+WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
+    Tiling(((64, 32), (16, 32), (8, 32), (8,))),
+    warp_dim=-7,
+    lane_dims=(-3, -2),
+    vector_dim=-1,
+)
 # This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores a 8x8
 # submatrix in the following way (we only show the first 4 rows for brevity):
 #
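A quick aside on the geometry described in the new layout's comment: the numbers work out to exactly one 32-bit register per lane. Below is a hedged, plain-Python sanity check of that arithmetic (my own sketch, not part of the commit):

elem_bits = 4                         # 4-bit inputs being upcast
rows, cols = 8, 32                    # core matrix stored by one warp
threads_per_row = 4                   # lanes that share one row of the core matrix
vector_len = cols // threads_per_row  # elements held in each lane's vector
assert vector_len == 8
assert vector_len * elem_bits == 32   # exactly one 32-bit register per lane
assert rows * threads_per_row == 32   # 8 rows x 4 lanes = a full warp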
@@ -739,58 +751,132 @@ def to_layout(self, new_layout: FragmentedLayout):
           _layout=new_layout,
           _is_signed=self.is_signed,
       )
-    if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 16 == 0:
-      if (
-          self.layout == WGMMA_LAYOUT_UPCAST_2X
-          and new_layout == WGMMA_LAYOUT
-          and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) in {8, 16}
-      ):
-        assert shape[1] % 16 == 0  # Should be implied by the layout
-        new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
-        is_even = arith.cmpi(
-            arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
+    if (
+        self.layout == WGMMA_LAYOUT_UPCAST_2X
+        and new_layout == WGMMA_LAYOUT
+        and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16
+    ):
+      assert shape[1] % 16 == 0  # Should be implied by the layout
+      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
+      is_even = arith.cmpi(
+          arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
+      )
+      registers = self.registers
+      if dtype_bitwidth == 4:
+        if registers.shape[1] % 2:
+          raise NotImplementedError(
+              "This relayout implementation requires an even number of column"
+              " tiles (to pack pairs of them for efficiency)"
+          )
+        # We pair up the consecutive column tiles, so each register is 32-bit.
+        # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout,
+        # LLVM will realize that the paired up vectors actually came from the
+        # same 32-bit register and it will become a no-op.
+        col_minor_registers = np.moveaxis(registers, 1, -1)
+        flat_registers = [
+            utils.vector_concat((l, h))
+            for l, h in zip(
+                col_minor_registers.flat[::2], col_minor_registers.flat[1::2]
+            )
+        ]
+        registers = np.asarray(flat_registers, dtype=object).reshape(
+            *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2
         )
-        for idx, reg in np.ndenumerate(self.registers):
-          assert ir.VectorType(reg.type).shape == [4]
-          if dtype_bitwidth == 16:
-            # A single vector is 64-bits, but shuffles are only 32-bit wide.
-            # We only shuffle the half that needs to go to other thread.
-            low = utils.vector_slice(reg, slice(0, 2))
-            high = utils.vector_slice(reg, slice(2, 4))
-            to_exchange = arith.select(is_even, high, low)
-            # Exchange values between even and odd threads.
-            exchanged = utils.shfl_bfly(to_exchange, 1)
-            low = arith.select(is_even, low, exchanged)
-            high = arith.select(is_even, exchanged, high)
-          elif dtype_bitwidth == 8:
-            # The vector is 32-bits, so we just shuffle the whole thing and
-            # use prmt to blend it with the local register.
-            exchanged = utils.shfl_bfly(reg, 1)
-            # Consider lanes 0 and 1, because the situation is symmetric for
-            # each pair. If we feed reg[lane] and exchanged[lane] (which is
-            # really the same as reg of the other lane) to prmt, we can index
-            # the elements of the result using the following indices:
-            #  reg[0]: 0 1 2 3    reg[1]: 8 9 10 11
-            # prmt[0]: 0 1 2 3    4 5 6 7
-            # prmt[1]: 4 5 6 7    0 1 2 3
-            # The expected outputs and their respective permutations are:
-            #  out[0]: 0 1 8 9    out[1]: 2 3 10 11
-            # prmt[0]: 0 1 4 5   prmt[1]: 6 7 2 3
-            # Note that the patterns still need to be flipped, since we listed
-            # bytes with LSB on the left, which is the opposite of how the
-            # numeric constants are spelled in Python (LSB on the right).
-            perm = arith.select(is_even, c(0x5410), c(0x3276))
-            blend = utils.prmt(reg, exchanged, perm)
-            low = utils.vector_slice(blend, slice(0, 2))
-            high = utils.vector_slice(blend, slice(2, 4))
-          else:
-            raise NotImplementedError(dtype_bitwidth)
+        registers = np.moveaxis(registers, -1, 1)
+      for idx, reg in np.ndenumerate(registers):
+        if dtype_bitwidth == 16:
+          assert reg.type.shape == [4]
+          # A single vector is 64-bits, but shuffles are only 32-bit wide.
+          # We only shuffle the half that needs to go to other thread.
+          low = utils.vector_slice(reg, slice(0, 2))
+          high = utils.vector_slice(reg, slice(2, 4))
+          to_exchange = arith.select(is_even, high, low)
+          # Exchange values between even and odd threads.
+          exchanged = utils.shfl_bfly(to_exchange, 1)
+          low = arith.select(is_even, low, exchanged)
+          high = arith.select(is_even, exchanged, high)
           new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low
           new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high
-        assert all(r is not None for r in new_registers)
-        return FragmentedArray(
-            _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
-        )
+        elif dtype_bitwidth == 8:
+          assert reg.type.shape == [4]
+          # The vector is 32-bits, so we just shuffle the whole thing and
+          # use prmt to blend it with the local register.
+          exchanged = utils.shfl_bfly(reg, 1)
+          # Consider lanes 0 and 1, because the situation is symmetric for
+          # each pair. If we feed reg[lane] and exchanged[lane] (which is
+          # really the same as reg of the other lane) to prmt, we can index
+          # the elements of the result using the following indices:
+          #  reg[0]: 0 1 2 3    reg[1]: 8 9 10 11
+          # prmt[0]: 0 1 2 3    4 5 6 7
+          # prmt[1]: 4 5 6 7    0 1 2 3
+          # The expected outputs and their respective permutations are:
+          #  out[0]: 0 1 8 9    out[1]: 2 3 10 11
+          # prmt[0]: 0 1 4 5   prmt[1]: 6 7 2 3
+          # Note that the patterns still need to be flipped, since we listed
+          # bytes with LSB on the left, which is the opposite of how the
+          # numeric constants are spelled in Python (LSB on the right).
+          perm = arith.select(is_even, c(0x5410), c(0x3276))
+          blend = utils.prmt(reg, exchanged, perm)
+          for i in range(2):
+            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
+            new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
+        else:
+          assert dtype_bitwidth == 4
+          assert reg.type.shape == [8]  # We paired up the registers above.
+          exchanged = utils.shfl_bfly(reg, 1)
+          # See comment above for a more complete explanation.
+          #  reg[0]: 0 1 2 3 16 17 18 19    reg[1]: 8 9 10 11 24 25 26 27
+          # prmt[0]: -0- -1- --2-- --3--    -4- --5-- --6-- --7--
+          # prmt[1]: -4- -5- --6-- --7--    -0- --1-- --2-- --3--
+          # The expected outputs and their respective permutations are:
+          #  out[0]: 0 1 8 9 16 17 24 25    out[1]: 2 3 10 11 18 19 26 27
+          # prmt[0]: -0- -4- --2-- --6--    prmt[1]: -5- --1-- --7-- --3--
+          perm = arith.select(is_even, c(0x6240), c(0x3715))
+          blend = utils.prmt(reg, exchanged, perm)
+          for i in range(4):
+            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
+            new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg
+      assert all(r is not None for r in new_registers)
+      return FragmentedArray(
+          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
+      )
+    if (
+        self.layout == WGMMA_LAYOUT_UPCAST_4X
+        and new_layout == WGMMA_LAYOUT_UPCAST_2X
+        and utils.bitwidth(self.mlir_dtype) == 4
+    ):
+      assert shape[0] % 64 == 0  # Should be implied by the layout
+      assert shape[1] % 32 == 0  # Should be implied by the layout
+      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
+      i32 = ir.IntegerType.get_signless(32)
+      c = lambda x: arith.constant(i32, x)
+      is_01 = arith.cmpi(
+          arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2)
+      )
+      for idx, reg in np.ndenumerate(self.registers):
+        assert ir.VectorType(reg.type).shape == [8]
+        # The vector is 32-bits, so we just shuffle the whole thing and
+        # use prmt to blend it with the local register.
+        exchanged = utils.shfl_bfly(reg, 2)
+        # See comments above for conventions. Here we exchange data between
+        # threads with lane index related by flipping 2nd bit (e.g. 0 and 2).
+        #  reg[0]: 0 1 2 3 4 5 6 7    reg[2]: 16 17 18 19 20 21 22 23
+        # prmt[0]: -0- -1- -2- -3-    --4-- --5-- --6-- --7--
+        # prmt[1]: -4- -5- -6- -7-    --0-- --1-- --2-- --3--
+        # The expected outputs and their respective permutations are:
+        #  out[0]: 0 1 2 3 16 17 18 19    out[2]: 4 5 6 7 20 21 22 23
+        # prmt[0]: -0- -1- --4-- --5--    prmt[2]: -6- -7- --2-- --3--
+        perm = arith.select(is_01, c(0x5410), c(0x3276))
+        blend = utils.prmt(reg, exchanged, perm)
+        for i in range(2):
+          reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4))
+          new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
+      assert all(r is not None for r in new_registers)
+      return FragmentedArray(
+          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
+      )
+    if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT:
+      return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout)
     if not isinstance(self.layout, WGSplatFragLayout):
       raise NotImplementedError(
           f"Cannot convert from {self.layout} to {new_layout}"
@@ -1288,7 +1374,9 @@ def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
         int_ty = ir.IntegerType.get_signless(group_size * 4)
         while vector_len - offset >= group_size:
           reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
-          reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty))
+          reg_slice_int = utils.bitcast(reg_slice, int_ty)
+          if int_ty != i32:
+            reg_slice_int = arith.extsi(i32, reg_slice_int)
           reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
           out_int_regs.extend(
               upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
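The prmt-based interleaving above is compact but easy to misread. Below is a small, hedged Python sketch (my own, not part of the commit) that simulates the default mode of the PTX prmt byte-select instruction and checks the selector constants used in the diff. It assumes the usual convention: the first operand supplies source bytes 0-3, the second bytes 4-7, and each hex digit of the selector, read from the least significant end, picks one output byte (all selectors here are 0-7, so sign replication never kicks in):

def prmt(lo, hi, sel):
  # Byte-level simulation of prmt's default mode.
  pool = list(lo) + list(hi)
  return [pool[(sel >> (4 * i)) & 0xF] for i in range(4)]

# 8-bit case: lane 0 holds elements 0 1 2 3, lane 1 holds 8 9 10 11.
even, odd = [0, 1, 2, 3], [8, 9, 10, 11]
assert prmt(even, odd, 0x5410) == [0, 1, 8, 9]    # even lanes
assert prmt(odd, even, 0x3276) == [2, 3, 10, 11]  # odd lanes

# 4-bit case: each byte packs a pair of elements, so label bytes by pairs.
even = [(0, 1), (2, 3), (16, 17), (18, 19)]
odd = [(8, 9), (10, 11), (24, 25), (26, 27)]
assert prmt(even, odd, 0x6240) == [(0, 1), (8, 9), (16, 17), (24, 25)]
assert prmt(odd, even, 0x3715) == [(2, 3), (10, 11), (18, 19), (26, 27)]

The recovered byte orders match the out[0]/out[1] rows in the diff comments, which is exactly what the 0x5410/0x3276 and 0x6240/0x3715 constants encode once the digits are flipped (LSB on the right in Python).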

tests/mosaic/gpu_test.py

Lines changed: 52 additions & 27 deletions
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 from collections.abc import Sequence
+import contextlib
 import dataclasses
 import enum
 import itertools
@@ -83,6 +84,20 @@ def mlir_sum(elems):
   return total
 
 
+@contextlib.contextmanager
+def get_sass():
+  prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
+  os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
+  try:
+    with jtu.capture_stdout() as output:
+      yield output
+  finally:
+    if prev_dump is not None:
+      os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
+    else:
+      del os.environ["MOSAIC_GPU_DUMP_SASS"]
+
+
 def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
   index = ir.IndexType.get()
   thread_id = gpu.thread_id(gpu.Dimension.x)
@@ -542,7 +557,11 @@ def kernel(ctx, inp, out, smem):
           (jnp.int8, jnp.bfloat16),
           (jnp.int4, jnp.bfloat16),
       ),
-      layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X),
+      layout=(
+          fa.WGMMA_LAYOUT,
+          fa.WGMMA_LAYOUT_UPCAST_2X,
+          fa.WGMMA_LAYOUT_UPCAST_4X,
+      ),
   )
   def test_optimized_conversion(self, jax_dtype_from_to, layout):
     jax_dtype_from, jax_dtype_to = jax_dtype_from_to
@@ -2194,19 +2213,11 @@ def kernel(ctx, in_, out, smems):
         .transpose(0, 2, 1, 3)
     )
 
-    prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
-    os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
-    try:
-      with jtu.capture_stdout() as get_sass:
-        iota = mgpu.as_gpu_kernel(
-            kernel, (1, 1, 1), (128, 1, 1), expected, expected,
-            [expected, expected, mgpu.TMABarrier()],
-        )(expected)
-    finally:
-      if prev_dump is not None:
-        os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
-      else:
-        del os.environ["MOSAIC_GPU_DUMP_SASS"]
+    with get_sass() as sass:
+      iota = mgpu.as_gpu_kernel(
+          kernel, (1, 1, 1), (128, 1, 1), expected, expected,
+          [expected, expected, mgpu.TMABarrier()],
+      )(expected)
     np.testing.assert_array_equal(iota, expected)
 
     # Verify that we don't use too many registers for the transfers.
@@ -2219,7 +2230,7 @@ def kernel(ctx, in_, out, smems):
       expected_regs //= 2
     for instr in ("STS", "LDS"):
       with self.subTest(instr + " count"):
-        addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
+        addrs = re.findall(instr + r".* \[(.*)\]", sass())
         def get_reg(addr):
           if (pos := addr.find("+")) != -1:
             return addr[:pos]
@@ -2294,30 +2305,38 @@ def kernel(ctx, in_, out, smems):
     )(x)
     np.testing.assert_array_equal(y, y_ref)
 
-  @parameterized.product(
-      upcast_before_layout_change=[True, False],
+  @parameterized.parameters(
+      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int8, 1),
+      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int16, 1),
+      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, jnp.int4, jnp.int4, 1),
+      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5),
+      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
   )
-  def test_upcast_to_wgmma(self, upcast_before_layout_change):
-    in_dtype = jnp.dtype(jnp.int8)
+  def test_upcast_to_wgmma(
+      self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
+  ):
+    in_dtype = jnp.dtype(in_dtype)
     out_dtype = jnp.dtype(jnp.int16)
+    out_dtype_mlir = utils.dtype_to_ir_type(out_dtype)
     swizzle = 128
     in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits
     in_tiling = (8, in_col_tiling)
     out_col_tiling = swizzle // out_dtype.itemsize
     out_tiling = (8, out_col_tiling)
     m, n = 128, in_col_tiling * 2
+    regs_per_thread = None
     def kernel(ctx, in_, out, smems):
+      nonlocal regs_per_thread
       smem_in, smem_out, barrier = smems
       ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
       barrier.wait()
       t = mgpu.FragmentedArray.load_tiled(
-          smem_in, swizzle=swizzle, is_signed=True, layout=fa.WGMMA_LAYOUT_UPCAST_2X
+          smem_in, swizzle=swizzle, is_signed=True, layout=start_layout
       )
-      if upcast_before_layout_change:
-        t = t.astype(ir.IntegerType.get_signless(16), is_signed=True)
-      t = t.to_layout(fa.WGMMA_LAYOUT)
-      if not upcast_before_layout_change:
-        t = t.astype(ir.IntegerType.get_signless(16), is_signed=True)
+      regs_per_thread = t.registers.size
+      t = t.astype(utils.dtype_to_ir_type(cast_dtype), is_signed=True)
+      t = t.to_layout(end_layout)
+      t = t.astype(out_dtype_mlir, is_signed=True)
       t.store_tiled(smem_out, swizzle=swizzle)
       mgpu.commit_shared()
       ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
@@ -2326,14 +2345,20 @@ def tile(x, tiling):
       return x.reshape(
           x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1]
       ).transpose(0, 2, 1, 3)
-    x = jax.random.randint(jax.random.key(42), (m, n), -128, 127, dtype=in_dtype)
+    in_iinfo = jnp.iinfo(in_dtype)
+    x = jax.random.randint(
+        jax.random.key(42), (m, n), in_iinfo.min, in_iinfo.max, dtype=jnp.int32
+    ).astype(in_dtype)
     xt = tile(x, in_tiling)
     y = x.astype(out_dtype)
     yt = tile(y, out_tiling)
     f = mgpu.as_gpu_kernel(
         kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()],
     )
-    np.testing.assert_array_equal(f(xt), yt)
+    with get_sass() as sass:
+      yt_kernel = f(xt)
+    np.testing.assert_array_equal(yt_kernel, yt)
+    self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
 
 
 @dataclasses.dataclass(frozen=True)
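The new SHFL.BFLY assertion ties the SASS back to the relayout paths added in fragmented_array.py. The shfl_per_reg factors in the parameterization can be read off those paths; the sketch below is my own accounting of them in plain Python (a hedged reading of the diff, not code from the commit, with the layouts abbreviated as strings):

def expected_shfl_per_reg(start, end, bits):
  # Bfly shuffles per *starting* register, per the relayout branches above.
  if start == "UPCAST_2X" and end == "WGMMA":
    # 8/16-bit: one shuffle per register. 4-bit: pairs of column tiles are
    # concatenated into half as many 32-bit registers before shuffling.
    return 1 if bits >= 8 else 0.5
  if start == "UPCAST_4X" and end == "UPCAST_2X":
    return 1  # one shuffle per 32-bit register
  if start == "UPCAST_4X" and end == "WGMMA":
    # Routed through UPCAST_2X: one shuffle there, then the doubled register
    # count is re-paired in the second step, costing one more per original.
    return 1 + 1
  raise NotImplementedError((start, end, bits))

assert expected_shfl_per_reg("UPCAST_2X", "WGMMA", 8) == 1
assert expected_shfl_per_reg("UPCAST_2X", "WGMMA", 16) == 1
assert expected_shfl_per_reg("UPCAST_2X", "WGMMA", 4) == 0.5
assert expected_shfl_per_reg("UPCAST_4X", "UPCAST_2X", 4) == 1
assert expected_shfl_per_reg("UPCAST_4X", "WGMMA", 4) == 2

These values line up with the last column of the @parameterized.parameters tuples in the test above.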
