-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Open
Labels
bug: Something isn't working
Description
The following indexing cases behave like NumPy on plain JAX arrays, but fail when the array is wrapped in a Ref created with `jax.new_ref`:
class RefIndexingTest(jtu.JaxTestCase):
  """Compares indexing of a ``jax.new_ref`` Ref against NumPy on edge cases.

  Each test builds the same data as a NumPy array and as a Ref, applies one
  indexer to both, and checks that the result shapes and values agree. With
  jax/jaxlib 0.8.1 all four cases fail for Refs.
  """
  # NOTE(review): presumably run with `import jax`, `import jax.numpy as jnp`,
  # `import numpy as np` and `from jax._src import test_util as jtu` in scope.

  def new_ref(self, x):
    """Wrap ``x`` in a Ref via ``jax.new_ref``.

    These tests all pass if we return ``x`` directly.
    """
    return jax.new_ref(x)

  def _check_matches_numpy(self, x_np, indexer):
    """Assert that a Ref built from ``x_np`` indexes like ``x_np[indexer]``."""
    x = self.new_ref(jnp.asarray(x_np))
    expected = x_np[indexer]
    self.assertEqual(x[indexer].shape, expected.shape)
    self.assertArraysEqual(x[indexer], expected)

  def test_negative_slice(self):
    """A slice with step -1 should walk the axis backwards.

    NumPy selects indices 8, 7, ..., 0 (shape (9,)). Observed on a Ref:
    ValueError: slice must have a step >= 1 (found: -1)
    """
    self._check_matches_numpy(np.zeros((10,)), slice(-2, -12, -1))

  def test_slice_oob_clamping(self):
    """A slice starting past the end of the axis should clamp to empty.

    NumPy returns an empty result of shape (0,). Observed on a Ref:
    ValueError: Out of bound slice: start=10, dim=10.
    """
    self._check_matches_numpy(np.zeros((10,)), slice(11, 12, 1))

  def test_none_index(self):
    """Indexing with ``None`` should insert a new axis of length 1.

    NumPy gives shape (1,) for a 0-d input. Observed on a Ref: the result
    shape is () instead of (1,).
    """
    self._check_matches_numpy(np.zeros((), dtype=np.int32), None)

  def test_insert_none_axis(self):
    """``(..., None)`` on a 0-d input should append a length-1 axis.

    NumPy gives shape (1,). Observed on a Ref: ValueError: `indices` must
    not be longer than `shape`: indices=(None,), shape=()
    """
    self._check_matches_numpy(np.zeros(()), (..., None))

# Result:
======================================================================
ERROR: test_slice_oob_clamping (__main__.RefIndexingTest)
RefIndexingTest.test_slice_oob_clamping
----------------------------------------------------------------------
ValueError: Out of bound slice: start=10, dim=10.
======================================================================
ERROR: test_insert_none_axis (__main__.RefIndexingTest)
RefIndexingTest.test_insert_none_axis
----------------------------------------------------------------------
ValueError: `indices` must not be longer than `shape`: indices=(None,), shape=()
======================================================================
ERROR: test_negative_slice (__main__.RefIndexingTest)
RefIndexingTest.test_negative_slice
----------------------------------------------------------------------
ValueError: slice must have a step >= 1 (found: -1)
======================================================================
FAIL: test_none_index (__main__.RefIndexingTest)
RefIndexingTest.test_none_index
----------------------------------------------------------------------
AssertionError: Tuples differ: () != (1,)
Second tuple contains 1 additional elements.
First extra element 0:
1
- ()
+ (1,)
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.8.1
jaxlib: 0.8.1
numpy: 2.3.4
python: 3.12.11
device info: cpu-1, 1 local devices
process_count: 1
platform: Linux
Metadata
Assignees
Labels
bug: Something isn't working