Skip to content

Commit 51497a1

Browse files
Support shape -1 for slice op in the jax backend (#21501)
* Support shape -1 for slice op in the jax backend. * Fix comments. * Update names * Address Gemini reviews.
1 parent 129e3d7 commit 51497a1

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

keras/src/backend/jax/core.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,13 @@ def scatter_update(inputs, indices, updates):
300300

301301

302302
def slice(inputs, start_indices, shape):
303-
return jax.lax.dynamic_slice(inputs, start_indices, shape)
303+
# If shape[i] is -1, all remaining elements in dimension i are included in
304+
# the slice.
305+
final_shape = tuple(
306+
inputs.shape[i] - start_indices[i] if s == -1 else s
307+
for i, s in enumerate(shape)
308+
)
309+
return jax.lax.dynamic_slice(inputs, start_indices, final_shape)
304310

305311

306312
def slice_update(inputs, start_indices, updates):

keras/src/ops/core.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,20 @@ def call(self, inputs, start_indices):
397397
return backend.core.slice(inputs, start_indices, self.shape)
398398

399399
def compute_output_spec(self, inputs, start_indices):
400-
return KerasTensor(self.shape, dtype=inputs.dtype)
400+
if any(s == -1 for s in self.shape) and isinstance(
401+
start_indices, KerasTensor
402+
):
403+
raise ValueError(
404+
"When using -1 in `shape`, `start_indices` should not be a "
405+
"KerasTensor. "
406+
)
407+
# If self.shape[i] is -1, all remaining elements in dimension i are
408+
# included in the slice.
409+
final_shape = tuple(
410+
inputs.shape[i] - start_indices[i] if s == -1 else s
411+
for i, s in enumerate(self.shape)
412+
)
413+
return KerasTensor(final_shape, dtype=inputs.dtype)
401414

402415

403416
@keras_export("keras.ops.slice")

keras/src/ops/core_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,19 @@ def test_slice(self):
265265
shape = (2, 2)
266266
self.assertEqual(core.slice(inputs, start_indices, shape).shape, (2, 2))
267267

268+
def test_slice_negative_one_shape(self):
269+
inputs = KerasTensor(shape=(3, 3), dtype="float32")
270+
start_indices = (1, 1)
271+
shape = (-1, -1)
272+
self.assertEqual(core.slice(inputs, start_indices, shape).shape, (2, 2))
273+
274+
def test_slice_negative_one_shape_raises(self):
275+
inputs = KerasTensor(shape=(3, 3), dtype="float32")
276+
start_indices = KerasTensor(shape=(2,), dtype="int32")
277+
shape = (-1, -1)
278+
with self.assertRaises(ValueError):
279+
core.slice(inputs, start_indices, shape)
280+
268281
def test_slice_update(self):
269282
inputs = KerasTensor((4, 4))
270283
start_indices = KerasTensor((2,))

0 commit comments

Comments
 (0)