Skip to content

Commit 83d0952

Browse files
megrez-yliuGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Allow stride 0 in strided_load for broadcast.
PiperOrigin-RevId: 835372731
1 parent 2ef6bde commit 83d0952

File tree

3 files changed

+49
-7
lines changed

3 files changed

+49
-7
lines changed

jax/_src/state/indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class Slice:
4242
stride: int = 1
4343

4444
def __post_init__(self):
45-
if self.stride < 1:
46-
raise ValueError("`stride` must be >= 1.")
45+
if self.stride < 0:
46+
raise ValueError("`stride` must be >= 0.")
4747

4848
@property
4949
def is_dynamic_start(self):

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op,
735735

736736
template <typename Op>
737737
LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
738-
VectorType vector_ty) {
738+
VectorType vector_ty, int64_t min_stride) {
739739
auto indices = op.getIndices();
740740
auto strides = op.getStrides();
741741
if (memref_ty.getRank() != indices.size()) {
@@ -754,8 +754,9 @@ LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
754754
return failure();
755755
}
756756
for (int64_t i = 0; i < memref_ty.getRank(); ++i) {
757-
if (strides[i] < 1) {
758-
op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1";
757+
if (strides[i] < min_stride) {
758+
op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= "
759+
<< min_stride;
759760
return failure();
760761
}
761762
}
@@ -764,12 +765,13 @@ LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
764765

765766
LogicalResult StridedLoadOp::verify() {
766767
return verifyStridedOp<StridedLoadOp>(*this, getMemRefType(getBase()),
767-
getType());
768+
getType(), /*min_stride=*/0);
768769
}
769770

770771
LogicalResult StridedStoreOp::verify() {
771772
return verifyStridedOp<StridedStoreOp>(*this, getMemRefType(getBase()),
772-
getValueToStore().getType());
773+
getValueToStore().getType(),
774+
/*min_stride=*/1);
773775
}
774776

775777
template <typename Op>

tests/pallas/indexing_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,46 @@ def body(x_ref, y_ref1, y_ref2):
541541
y2[slices], expected, err_msg="Strided Store Error"
542542
)
543543

544+
@hp.given(hps.data())
545+
def test_load_and_broadcast_with_stride_0(self, data):
546+
if not jtu.if_cloud_tpu_at_least(2025, 11, 25):
547+
self.skipTest("Requires libtpu built after 2025-11-25")
548+
if self.INTERPRET:
549+
self.skipTest("TODO: fails in interpret mode.")
550+
dtype = jnp.float32
551+
rank = data.draw(hps.integers(min_value=2, max_value=4))
552+
shape = data.draw(hps.tuples(
553+
*(hps.integers(min_value=1, max_value=10) for _ in range(rank - 1))))
554+
shape = (*shape, 128)
555+
556+
strides = data.draw(hps.tuples(
557+
*(hps.sampled_from([0, 1]) for _ in range(rank - 1))))
558+
strides = (*strides, 1)
559+
560+
indices = []
561+
for i in range(rank):
562+
index = (data.draw(hps.integers(min_value=0, max_value=shape[i] - 1))
563+
if strides[i] == 0 else 0)
564+
indices.append(index)
565+
566+
def body(x_ref, y_ref):
567+
slices = tuple(
568+
pl.ds(i, l, s) for i, l, s in zip(indices, shape, strides)
569+
)
570+
y_ref[...] = x_ref[slices]
571+
572+
x = random.normal(random.key(33), shape, dtype=dtype)
573+
y = self.pallas_call(
574+
body,
575+
out_shape=jax.ShapeDtypeStruct(shape, dtype),
576+
)(x)
577+
slices = tuple(slice(i, l, 1) if s != 0 else slice(i, i + 1, 1)
578+
for i, l, s in zip(indices, shape, strides))
579+
580+
expected = jnp.broadcast_to(x[slices], shape)
581+
self.assertAllClose(y, expected)
582+
583+
544584
def test_load_with_dynamic_2nd_minor_index(self):
545585
if pltpu is None:
546586
self.skipTest("No TPU module available.")

0 commit comments

Comments
 (0)