Merged
33 changes: 15 additions & 18 deletions jax/_src/pallas/mosaic/sc_lowering.py
@@ -323,12 +323,6 @@ def _load_lowering_rule(
   assert isinstance(ref_aval, state.AbstractRef)
   [out_aval] = ctx.avals_out
   assert isinstance(out_aval, jax_core.ShapedArray)
-  in_smem = ref_aval.memory_space is tpu_core.MemorySpace.SMEM
-  if in_smem:
-    if out_aval.ndim:
-      raise NotImplementedError("Get can only load scalars from SMEM")
-  else:
-    _check_aval_is_supported("Get", out_aval)

   transforms = list(jax.tree.unflatten(tree, flat_transforms))
   if not transforms or not isinstance(transforms[-1], indexing.NDIndexer):
@@ -350,11 +344,16 @@ def _load_lowering_rule(
         f"Get only supports slices with stride 1, got {strides}"
     )

-  if in_smem:
+  if not out_aval.ndim:
     if mask is not None:
-      raise NotImplementedError("Get does not support masked loads from SMEM")
+      raise NotImplementedError("Get does not support masked scalar loads")
     return memref.load(ref, starts)

+  if ref_aval.memory_space is tpu_core.MemorySpace.SMEM:
+    raise NotImplementedError("Get can only load scalars from SMEM")
+  else:
+    _check_aval_is_supported("Get", out_aval)
+
   vec_type = ir.VectorType.get(
       out_aval.shape, _dtype_to_ir_type(ref_aval.dtype)
   )
Expand All @@ -378,13 +377,6 @@ def _store_lowering_rule(
[out_aval] = ctx.avals_out
assert isinstance(out_aval, jax_core.ShapedArray)

in_smem = ref_aval.memory_space is tpu_core.MemorySpace.SMEM
if in_smem:
if out_aval.ndim:
raise NotImplementedError("Swap can only store scalars to SMEM")
else:
_check_aval_is_supported("Swap", out_aval)

transforms = list(jax.tree.unflatten(tree, flat_transforms))
if not transforms or not isinstance(transforms[-1], indexing.NDIndexer):
ref_shape = (
@@ -409,17 +401,22 @@ def _store_lowering_rule(
         f"Swap only supports slices with stride 1, got {strides}"
     )

-  if in_smem:
+  if not out_aval.ndim:
     if mask is not None:
-      raise NotImplementedError("Swap does not support masked stores to SMEM")
+      raise NotImplementedError("Swap does not support masked scalar stores")
     if add:
       # TODO(slebedev): We can use memref.atomic_rmw here, but the SC compiler
       # doesn't support it yet.
-      raise NotImplementedError("Swap does not support atomic adds to SMEM")
+      raise NotImplementedError("Swap does not support atomic scalar adds")
     old_val = memref.load(ref, starts)
     memref.store(val, ref, starts)
     return old_val

+  if ref_aval.memory_space is tpu_core.MemorySpace.SMEM:
+    raise NotImplementedError("Swap can only store scalars to SMEM")
+  else:
+    _check_aval_is_supported("Swap", out_aval)
+
   vec_type = ir.VectorType.get(
       out_aval.shape, _dtype_to_ir_type(ref_aval.dtype)
   )
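Note: both lowering rules now share the same dispatch order. The scalar (0-d) case is handled first and lowers to a plain memref.load / memref.store for any memory space (previously that path was taken only for SMEM refs), and the SMEM restriction then applies only to the remaining vector path. A condensed sketch of the Get side, using the names from the diff rather than the real rule signature:

def _get_dispatch_sketch(ref_aval, out_aval, ref, starts, mask):
  # Scalar access lowers directly to memref.load, regardless of memory space.
  if not out_aval.ndim:
    if mask is not None:
      raise NotImplementedError("Get does not support masked scalar loads")
    return memref.load(ref, starts)
  # Vector access from SMEM remains unsupported.
  if ref_aval.memory_space is tpu_core.MemorySpace.SMEM:
    raise NotImplementedError("Get can only load scalars from SMEM")
  _check_aval_is_supported("Get", out_aval)
  # ... vector path: build an ir.VectorType and emit the vector load ...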
16 changes: 16 additions & 0 deletions tests/pallas/tpu_sparsecore_pallas_test.py
@@ -728,6 +728,22 @@ def kernel(x_ref, o_ref):
         x_ref.at[pl.ds(2, 8)], mask=jnp.arange(8) % 2 == 0)
     np.testing.assert_array_equal(kernel(x)[5:13:2], x[2:6])

+  def test_scalar_load_store(self):
+
+    @vector_subcore_kernel(
+        in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
+        out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
+        out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
+        scratch_shapes=(pltpu.VMEM((1,), jnp.int32),),
+    )
+    def kernel(x_ref, o_ref, tmp_ref):
+      pltpu.sync_copy(x_ref, tmp_ref)
+      o_ref[...] = lax.broadcast(tmp_ref[0], o_ref.shape)
+
+    np.testing.assert_array_equal(
+        kernel(jnp.ones((1,), jnp.int32)), jnp.ones((8,), jnp.int32)
+    )
+
   @parameterized.named_parameters(
       ("mixed", [0, 0, 1, 0, 1, 0, 0, 0], 2),
       ("all_zero", [0, 0, 0, 0, 0, 0, 0, 0], 8),
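The new test exercises the scalar load path end to end: pltpu.sync_copy stages one element from HBM into a VMEM scratch ref, tmp_ref[0] takes the new scalar memref.load path, and the broadcast fans the value out across the VMEM output. A companion sketch for the scalar store path enabled by the _store_lowering_rule change might look like the following; it reuses the test-local vector_subcore_kernel helper, and the specs and semantics are an assumption, not part of this PR:

  def test_scalar_store(self):  # hypothetical companion test, not in this PR

    @vector_subcore_kernel(
        in_specs=(pl.BlockSpec(memory_space=pltpu.VMEM),),
        out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
        out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
        scratch_shapes=(pltpu.VMEM((1,), jnp.int32),),
    )
    def kernel(x_ref, o_ref, tmp_ref):
      tmp_ref[0] = x_ref[0] + 1  # scalar load, then the new scalar store path
      pltpu.sync_copy(tmp_ref, o_ref)

    np.testing.assert_array_equal(
        kernel(jnp.zeros((1,), jnp.int32)), jnp.ones((1,), jnp.int32)
    )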