Skip to content

Commit f3acfa9

Browse files
cperivol authored and Google-ML-Automation committed
[mgpu] FragmentedArray.foreach() can now optionally return a new array
PiperOrigin-RevId: 700708119
1 parent 03b6945 commit f3acfa9

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1243,15 +1243,29 @@ def select(self, on_true, on_false):
12431243
lambda t, p, f: arith.select(p, t, f), self, on_false,
12441244
)
12451245

1246-
def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]):
1246+
def foreach(
1247+
self,
1248+
fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None],
1249+
*,
1250+
create_array=False,
1251+
is_signed=None,
1252+
):
12471253
"""Call a function for each value and index."""
12481254
index = ir.IndexType.get()
1249-
for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True):
1250-
assert len(idx) == len(self.shape), (idx, self.shape)
1255+
new_regs = None
1256+
if create_array:
1257+
new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type))
1258+
for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True):
1259+
reg = self.registers[reg_idx]
1260+
assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape)
12511261
[elems] = ir.VectorType(reg.type).shape
12521262
for i in range(elems):
12531263
i = c(i, index)
1254-
fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i)))
1264+
val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i)))
1265+
if create_array:
1266+
new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i)
1267+
1268+
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
12551269

12561270
def store_untiled(self, ref: ir.Value):
12571271
if not ir.MemRefType.isinstance(ref.type):

tests/mosaic/gpu_test.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1361,6 +1361,39 @@ def kernel(ctx, dst, _):
13611361
rhs = rhs = 0 if rhs_is_literal else iota + 1
13621362
np.testing.assert_array_equal(result, op(iota, rhs))
13631363

1364+
def test_foreach(self):
1365+
dtype = jnp.int32
1366+
swizzle = 128
1367+
tile = 64, swizzle // jnp.dtype(dtype).itemsize
1368+
shape = 128, 192
1369+
tiled_shape = mgpu.tile_shape(shape, tile)
1370+
mlir_dtype = utils.dtype_to_ir_type(dtype)
1371+
cst = 9999
1372+
def causal(val, idx):
1373+
row, col = idx
1374+
mask = arith.cmpi(arith.CmpIPredicate.uge, row, col)
1375+
return arith.select(mask, val, c(cst, mlir_dtype))
1376+
1377+
tiling = mgpu.TileTransform(tile)
1378+
def kernel(ctx, dst, smem):
1379+
x = iota_tensor(shape[0], shape[1], dtype)
1380+
x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem)
1381+
mgpu.commit_shared()
1382+
ctx.async_copy(src_ref=smem, dst_ref=dst)
1383+
ctx.await_async_copy(0)
1384+
1385+
iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape)
1386+
result = mgpu.as_gpu_kernel(
1387+
kernel,
1388+
(1, 1, 1),
1389+
(128, 1, 1),
1390+
(),
1391+
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
1392+
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
1393+
)()
1394+
expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst
1395+
np.testing.assert_array_equal(result, expected)
1396+
13641397
@parameterized.product(
13651398
op=[operator.and_, operator.or_, operator.xor],
13661399
dtype=[jnp.uint32],

0 commit comments

Comments
 (0)