Skip to content

Commit 4d6f15f

Browse files
apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add support for slicing tiled refs with (tile aligned) dynamic base offsets
PiperOrigin-RevId: 738762062
1 parent 1c8e60e commit 4d6f15f

File tree

3 files changed

+93
-20
lines changed

3 files changed

+93
-20
lines changed

jax/_src/pallas/mosaic_gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ pytype_strict_library(
8080
"//jax:mosaic_gpu",
8181
"//jax:state_types",
8282
"//jax:tree_util",
83+
"//jax/_src/lib",
8384
"//jax/_src/pallas",
8485
"//jaxlib/mlir:ir",
8586
] + py_deps("numpy"),

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@
2929
from jax._src import dtypes
3030
from jax._src import effects
3131
from jax._src import tree_util
32+
from jax._src.lib.mlir.dialects import arith as arith_dialect
3233
from jax._src.pallas import core as pallas_core
34+
from jax._src.state import discharge as state_discharge
3335
from jax._src.state import indexing
3436
from jax._src.state import types as state_types
35-
from jax._src.state import discharge as state_discharge
3637
import jax.experimental.mosaic.gpu as mgpu
3738
import jax.numpy as jnp
3839
from jaxlib.mlir import ir
@@ -135,6 +136,24 @@ def cmap_body():
135136
return wrapper
136137

137138

139+
def _is_known_divisible(value, divisor, fuel=10) -> bool:
140+
"""Returns True if the value is statically known to be divisible by the divisor."""
141+
if fuel < 0:
142+
return False
143+
if not isinstance(value.owner, ir.Operation):
144+
return False
145+
def_op = value.owner.opview
146+
match def_op:
147+
case arith_dialect.IndexCastOp():
148+
return _is_known_divisible(value.owner.operands[0], divisor, fuel - 1)
149+
case arith_dialect.ConstantOp():
150+
return ir.IntegerAttr(def_op.value).value % divisor == 0
151+
case arith_dialect.MulIOp():
152+
return (_is_known_divisible(value.owner.operands[0], divisor, fuel // 2) or
153+
_is_known_divisible(value.owner.operands[1], divisor, (fuel + 1)// 2))
154+
return False
155+
156+
138157
@dataclasses.dataclass(frozen=True)
139158
class GPUMemoryRef(pallas_core.MemoryRef):
140159
transforms: Sequence[MemoryRefTransform] = ()
@@ -171,7 +190,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
171190
shape=self.to_gpu_transform().transform_shape(aval.shape)
172191
)
173192

174-
Index = slice | int | ir.Value
193+
Index = mgpu.DynamicSlice | slice | int | ir.Value
175194

176195
@dataclasses.dataclass(frozen=True)
177196
class TilingTransform(MemoryRefTransform):
@@ -218,16 +237,37 @@ def untransform_index(
218237
) -> tuple[tuple[Index, ...], state_types.Transform]:
219238
untiled_idxs = idxs[: -len(self.tiling)]
220239
tiled_idxs = idxs[-len(self.tiling) :]
221-
idxs_after_tiling = []
240+
idxs_after_tiling: list[Index] = []
222241
for idx, tile in zip(tiled_idxs, self.tiling):
223-
if not isinstance(idx, slice):
224-
raise NotImplementedError("Non-slice indices are not supported")
225-
assert isinstance(idx, slice)
226-
if idx.step is not None and idx.step != 1:
227-
raise NotImplementedError("Strided slices unsupported")
228-
if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile):
229-
raise ValueError("Non-empty slices must be tile aligned")
230-
idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile))
242+
if isinstance(idx, slice):
243+
if idx.step is not None and idx.step != 1:
244+
raise NotImplementedError("Strided slices unsupported")
245+
if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile):
246+
raise ValueError("Non-empty slices must be tile aligned")
247+
idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile))
248+
elif isinstance(idx, mgpu.DynamicSlice):
249+
if idx.length % tile:
250+
raise ValueError(
251+
f"Dynamic slice length ({idx.length}) is not divisible by the"
252+
f" tiling ({tile})"
253+
)
254+
if isinstance(idx.base, ir.Value):
255+
if not _is_known_divisible(idx.base, tile):
256+
raise ValueError(
257+
"Dynamic slice base index (which is a dynamic value) cannot be"
258+
f" statically proven to be divisible by the tiling ({tile})"
259+
)
260+
new_base = arith_dialect.divui(idx.base, mgpu.c(tile, idx.base.type))
261+
else:
262+
if idx.base % tile:
263+
raise ValueError(
264+
f"Dynamic slice base ({idx.base}) is not divisible by the"
265+
f" tiling ({tile})"
266+
)
267+
new_base = idx.base // tile
268+
idxs_after_tiling.append(mgpu.DynamicSlice(new_base, idx.length // tile))
269+
else:
270+
raise TypeError(f"Unsupported index type: {type(idx)}")
231271
return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)), self
232272

233273
def undo_to_gpu_transform(self) -> mgpu.MemRefTransform:
@@ -285,7 +325,7 @@ def untransform_index(
285325
self, idxs: tuple[Index, ...]
286326
) -> tuple[tuple[Index, ...], state_types.Transform]:
287327
removed_dims = [
288-
i for i, idx in enumerate(idxs) if not isinstance(idx, slice)
328+
i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds))
289329
]
290330
new_perm = tuple(
291331
p - sum(d < p for d in removed_dims)
@@ -358,18 +398,22 @@ def untransform_index(
358398
) -> tuple[tuple[Index, ...], state_types.Transform]:
359399
if not idxs:
360400
return idxs, self
361-
if not all(isinstance(idx, slice) for idx in idxs[-2:]):
401+
if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]):
362402
raise NotImplementedError(
363403
"Non-slice indices are not supported in 2 minormost dims"
364404
)
365405
last_idx = idxs[-1]
366-
assert isinstance(last_idx, slice)
367-
if last_idx.step is not None and last_idx.step != 1:
368-
raise NotImplementedError("Swizzled dims cannot be sliced")
369-
if (last_idx.start is not None and last_idx.start != 0) or (
370-
last_idx.stop is not None and last_idx.stop != self.swizzle
371-
):
372-
raise ValueError("Swizzled dims cannot be sliced")
406+
if isinstance(last_idx, mgpu.DynamicSlice):
407+
if last_idx.base != 0 or last_idx.length != self.swizzle:
408+
raise ValueError("Swizzled dims cannot be sliced")
409+
else:
410+
assert isinstance(last_idx, slice)
411+
if (
412+
(last_idx.step is not None and last_idx.step != 1)
413+
or (last_idx.start is not None and last_idx.start != 0)
414+
or (last_idx.stop is not None and last_idx.stop != self.swizzle)
415+
):
416+
raise ValueError("Swizzled dims cannot be sliced")
373417
return idxs, self
374418

375419

tests/pallas/mosaic_gpu_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,34 @@ def kernel(x_ref, o_ref):
11321132
x = jnp.arange(256, dtype=jnp.int32)
11331133
np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256]))
11341134

1135+
# Not testing with warpgroup semantics, because we want to enforce a layout.
1136+
def test_tile_slicing(self):
  """Sums all (64, 64) tiles of a tiled and swizzled ref via dynamic slices.

  The row offset of each slice is a loop-carried (dynamic) index, while the
  column offset is a Python integer. The result is compared against the same
  reduction computed on the whole array outside of the kernel.
  """
  tile = 64
  shape = (256, 128)
  rows, cols = shape
  transforms = (
      plgpu.TilingTransform((tile, tile)),
      plgpu.SwizzleTransform(128),
  )
  block_spec = plgpu.GPUBlockSpec(transforms=transforms)

  def kernel_body(x_ref, o_ref):
    def sum_tiles(row, acc):
      # The row offset is only known at run time.
      rows_of_tile = pl.ds(row * tile, tile)
      for col in range(cols // tile):
        acc += x_ref[rows_of_tile, pl.ds(col * tile, tile)]
      return acc

    init = plgpu.layout_cast(
        jnp.zeros((tile, tile), jnp.uint16), plgpu.Layout.WGMMA
    )
    o_ref[...] = _fori_loop(False, 0, rows // tile, sum_tiles, init)

  kernel = pl.pallas_call(
      kernel_body,
      in_specs=[block_spec],
      out_specs=block_spec,
      out_shape=jax.ShapeDtypeStruct((tile, tile), jnp.uint16),
  )

  x = jnp.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape)
  # Split both dimensions into (tile index, offset in tile) and reduce over
  # the tile indices.
  tiled = x.reshape(rows // tile, tile, cols // tile, tile)
  expected = tiled.sum(axis=(0, 2), dtype=jnp.uint16)
  np.testing.assert_array_equal(kernel(x), expected)
1162+
11351163
def test_input_output_aliases(self):
11361164
# Note that we're writing to the input pointer, which should alias b_ptr.
11371165
def kernel(a_ref, b_ref):

0 commit comments

Comments
 (0)