|
29 | 29 | from jax._src import dtypes |
30 | 30 | from jax._src import effects |
31 | 31 | from jax._src import tree_util |
| 32 | +from jax._src.lib.mlir.dialects import arith as arith_dialect |
32 | 33 | from jax._src.pallas import core as pallas_core |
| 34 | +from jax._src.state import discharge as state_discharge |
33 | 35 | from jax._src.state import indexing |
34 | 36 | from jax._src.state import types as state_types |
35 | | -from jax._src.state import discharge as state_discharge |
36 | 37 | import jax.experimental.mosaic.gpu as mgpu |
37 | 38 | import jax.numpy as jnp |
38 | 39 | from jaxlib.mlir import ir |
@@ -135,6 +136,24 @@ def cmap_body(): |
135 | 136 | return wrapper |
136 | 137 |
|
137 | 138 |
|
| 139 | +def _is_known_divisible(value, divisor, fuel=10) -> bool: |
| 140 | + """Returns True if the value is statically known to be divisible by the divisor.""" |
| 141 | + if fuel < 0: |
| 142 | + return False |
| 143 | + if not isinstance(value.owner, ir.Operation): |
| 144 | + return False |
| 145 | + def_op = value.owner.opview |
| 146 | + match def_op: |
| 147 | + case arith_dialect.IndexCastOp(): |
| 148 | + return _is_known_divisible(value.owner.operands[0], divisor, fuel - 1) |
| 149 | + case arith_dialect.ConstantOp(): |
| 150 | + return ir.IntegerAttr(def_op.value).value % divisor == 0 |
| 151 | + case arith_dialect.MulIOp(): |
| 152 | + return (_is_known_divisible(value.owner.operands[0], divisor, fuel // 2) or |
| 153 | + _is_known_divisible(value.owner.operands[1], divisor, (fuel + 1)// 2)) |
| 154 | + return False |
| 155 | + |
| 156 | + |
138 | 157 | @dataclasses.dataclass(frozen=True) |
139 | 158 | class GPUMemoryRef(pallas_core.MemoryRef): |
140 | 159 | transforms: Sequence[MemoryRefTransform] = () |
@@ -171,7 +190,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: |
171 | 190 | shape=self.to_gpu_transform().transform_shape(aval.shape) |
172 | 191 | ) |
173 | 192 |
|
174 | | -Index = slice | int | ir.Value |
| 193 | +Index = mgpu.DynamicSlice | slice | int | ir.Value |
175 | 194 |
|
176 | 195 | @dataclasses.dataclass(frozen=True) |
177 | 196 | class TilingTransform(MemoryRefTransform): |
@@ -218,16 +237,37 @@ def untransform_index( |
218 | 237 | ) -> tuple[tuple[Index, ...], state_types.Transform]: |
219 | 238 | untiled_idxs = idxs[: -len(self.tiling)] |
220 | 239 | tiled_idxs = idxs[-len(self.tiling) :] |
221 | | - idxs_after_tiling = [] |
| 240 | + idxs_after_tiling: list[Index] = [] |
222 | 241 | for idx, tile in zip(tiled_idxs, self.tiling): |
223 | | - if not isinstance(idx, slice): |
224 | | - raise NotImplementedError("Non-slice indices are not supported") |
225 | | - assert isinstance(idx, slice) |
226 | | - if idx.step is not None and idx.step != 1: |
227 | | - raise NotImplementedError("Strided slices unsupported") |
228 | | - if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): |
229 | | - raise ValueError("Non-empty slices must be tile aligned") |
230 | | - idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) |
| 242 | + if isinstance(idx, slice): |
| 243 | + if idx.step is not None and idx.step != 1: |
| 244 | + raise NotImplementedError("Strided slices unsupported") |
| 245 | + if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): |
| 246 | + raise ValueError("Non-empty slices must be tile aligned") |
| 247 | + idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) |
| 248 | + elif isinstance(idx, mgpu.DynamicSlice): |
| 249 | + if idx.length % tile: |
| 250 | + raise ValueError( |
| 251 | + f"Dynamic slice length ({idx.length}) is not divisible by the" |
| 252 | + f" tiling ({tile})" |
| 253 | + ) |
| 254 | + if isinstance(idx.base, ir.Value): |
| 255 | + if not _is_known_divisible(idx.base, tile): |
| 256 | + raise ValueError( |
| 257 | + "Dynamic slice base index (which is a dynamic value) cannot be" |
| 258 | + f" statically proven to be divisible by the tiling ({tile})" |
| 259 | + ) |
| 260 | + new_base = arith_dialect.divui(idx.base, mgpu.c(tile, idx.base.type)) |
| 261 | + else: |
| 262 | + if idx.base % tile: |
| 263 | + raise ValueError( |
| 264 | + f"Dynamic slice base ({idx.base}) is not divisible by the" |
| 265 | + f" tiling ({tile})" |
| 266 | + ) |
| 267 | + new_base = idx.base // tile |
| 268 | + idxs_after_tiling.append(mgpu.DynamicSlice(new_base, idx.length // tile)) |
| 269 | + else: |
| 270 | + raise TypeError(f"Unsupported index type: {type(idx)}") |
231 | 271 | return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)), self |
232 | 272 |
|
233 | 273 | def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: |
@@ -285,7 +325,7 @@ def untransform_index( |
285 | 325 | self, idxs: tuple[Index, ...] |
286 | 326 | ) -> tuple[tuple[Index, ...], state_types.Transform]: |
287 | 327 | removed_dims = [ |
288 | | - i for i, idx in enumerate(idxs) if not isinstance(idx, slice) |
| 328 | + i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds)) |
289 | 329 | ] |
290 | 330 | new_perm = tuple( |
291 | 331 | p - sum(d < p for d in removed_dims) |
@@ -358,18 +398,22 @@ def untransform_index( |
358 | 398 | ) -> tuple[tuple[Index, ...], state_types.Transform]: |
359 | 399 | if not idxs: |
360 | 400 | return idxs, self |
361 | | - if not all(isinstance(idx, slice) for idx in idxs[-2:]): |
| 401 | + if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]): |
362 | 402 | raise NotImplementedError( |
363 | 403 | "Non-slice indices are not supported in 2 minormost dims" |
364 | 404 | ) |
365 | 405 | last_idx = idxs[-1] |
366 | | - assert isinstance(last_idx, slice) |
367 | | - if last_idx.step is not None and last_idx.step != 1: |
368 | | - raise NotImplementedError("Swizzled dims cannot be sliced") |
369 | | - if (last_idx.start is not None and last_idx.start != 0) or ( |
370 | | - last_idx.stop is not None and last_idx.stop != self.swizzle |
371 | | - ): |
372 | | - raise ValueError("Swizzled dims cannot be sliced") |
| 406 | + if isinstance(last_idx, mgpu.DynamicSlice): |
| 407 | + if last_idx.base != 0 or last_idx.length != self.swizzle: |
| 408 | + raise ValueError("Swizzled dims cannot be sliced") |
| 409 | + else: |
| 410 | + assert isinstance(last_idx, slice) |
| 411 | + if ( |
| 412 | + (last_idx.step is not None and last_idx.step != 1) |
| 413 | + or (last_idx.start is not None and last_idx.start != 0) |
| 414 | + or (last_idx.stop is not None and last_idx.stop != self.swizzle) |
| 415 | + ): |
| 416 | + raise ValueError("Swizzled dims cannot be sliced") |
373 | 417 | return idxs, self |
374 | 418 |
|
375 | 419 |
|
|
0 commit comments