diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py
index 029c87f39c0d..788758e13fbc 100644
--- a/jax/_src/pallas/mosaic/sc_lowering.py
+++ b/jax/_src/pallas/mosaic/sc_lowering.py
@@ -323,12 +323,6 @@ def _load_lowering_rule(
   assert isinstance(ref_aval, state.AbstractRef)
   [out_aval] = ctx.avals_out
   assert isinstance(out_aval, jax_core.ShapedArray)
-  in_smem = ref_aval.memory_space is tpu_core.MemorySpace.SMEM
-  if in_smem:
-    if out_aval.ndim:
-      raise NotImplementedError("Get can only load scalars from SMEM")
-  else:
-    _check_aval_is_supported("Get", out_aval)
 
   transforms = list(jax.tree.unflatten(tree, flat_transforms))
   if not transforms or not isinstance(transforms[-1], indexing.NDIndexer):
@@ -350,11 +344,16 @@
         "Get only supports slices with stride 1, got {strides}"
     )
 
-  if in_smem:
+  if not out_aval.ndim:
     if mask is not None:
-      raise NotImplementedError("Get does not support masked loads from SMEM")
+      raise NotImplementedError("Get does not support masked scalar loads")
     return memref.load(ref, starts)
 
+  if ref_aval.memory_space is tpu_core.MemorySpace.SMEM:
+    raise NotImplementedError("Get can only load scalars from SMEM")
+  else:
+    _check_aval_is_supported("Get", out_aval)
+
   vec_type = ir.VectorType.get(
       out_aval.shape, _dtype_to_ir_type(ref_aval.dtype)
   )
@@ -378,13 +377,6 @@ def _store_lowering_rule(
   [out_aval] = ctx.avals_out
   assert isinstance(out_aval, jax_core.ShapedArray)
 
-  in_smem = ref_aval.memory_space is tpu_core.MemorySpace.SMEM
-  if in_smem:
-    if out_aval.ndim:
-      raise NotImplementedError("Swap can only store scalars to SMEM")
-  else:
-    _check_aval_is_supported("Swap", out_aval)
-
   transforms = list(jax.tree.unflatten(tree, flat_transforms))
   if not transforms or not isinstance(transforms[-1], indexing.NDIndexer):
     ref_shape = (
@@ -409,17 +401,22 @@
         "Swap only supports slices with stride 1, got {strides}"
     )
 
-  if in_smem:
+  if not out_aval.ndim:
     if mask is not None:
-      raise NotImplementedError("Swap does not support masked stores to SMEM")
+      raise NotImplementedError("Swap does not support masked scalar stores")
     if add:
       # TODO(slebedev): We can use memref.atomic_rmw here, but the SC compiler
       # doesn't support it yet.
-      raise NotImplementedError("Swap does not support atomic adds to SMEM")
+      raise NotImplementedError("Swap does not support atomic scalar adds")
     old_val = memref.load(ref, starts)
     memref.store(val, ref, starts)
     return old_val
 
+  if ref_aval.memory_space is tpu_core.MemorySpace.SMEM:
+    raise NotImplementedError("Swap can only store scalars to SMEM")
+  else:
+    _check_aval_is_supported("Swap", out_aval)
+
   vec_type = ir.VectorType.get(
       out_aval.shape, _dtype_to_ir_type(ref_aval.dtype)
   )
diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py
index 931c500aa783..2b7f81b20187 100644
--- a/tests/pallas/tpu_sparsecore_pallas_test.py
+++ b/tests/pallas/tpu_sparsecore_pallas_test.py
@@ -728,6 +728,22 @@ def kernel(x_ref, o_ref):
          x_ref.at[pl.ds(2, 8)], mask=jnp.arange(8) % 2 == 0)
     np.testing.assert_array_equal(kernel(x)[5:13:2], x[2:6])
 
+  def test_scalar_load_store(self):
+
+    @vector_subcore_kernel(
+        in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
+        out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
+        out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
+        scratch_shapes=(pltpu.VMEM((1,), jnp.int32),),
+    )
+    def kernel(x_ref, o_ref, tmp_ref):
+      pltpu.sync_copy(x_ref, tmp_ref)
+      o_ref[...] = lax.broadcast(tmp_ref[0], o_ref.shape)
+
+    np.testing.assert_array_equal(
+        kernel(jnp.ones((1,), jnp.int32)), jnp.ones((8,), jnp.int32)
+    )
+
   @parameterized.named_parameters(
       ("mixed", [0, 0, 1, 0, 1, 0, 0, 0], 2),
       ("all_zero", [0, 0, 0, 0, 0, 0, 0, 0], 8),
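
Note: the new test above exercises the relaxed Get path (a scalar read, tmp_ref[0], from a VMEM scratch ref). A minimal sketch of the relaxed Swap path could look like the following. This is only an illustration of what the lowering change permits, not part of the diff: the test name, the added scalar-store line, and the expected values are hypothetical, and it reuses the test file's local vector_subcore_kernel helper and the same imports as the test above.

  def test_scalar_store_to_vmem(self):  # hypothetical companion test, not in this change

    @vector_subcore_kernel(
        in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
        out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
        out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
        scratch_shapes=(pltpu.VMEM((1,), jnp.int32),),
    )
    def kernel(x_ref, o_ref, tmp_ref):
      pltpu.sync_copy(x_ref, tmp_ref)
      # Scalar store to a non-SMEM (VMEM) ref, now lowered through memref.store.
      tmp_ref[0] = tmp_ref[0] + 1
      # Scalar load from the same ref, broadcast into the vector output.
      o_ref[...] = lax.broadcast(tmp_ref[0], o_ref.shape)

    np.testing.assert_array_equal(
        kernel(jnp.ones((1,), jnp.int32)), jnp.full((8,), 2, jnp.int32)
    )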