Skip to content

Discrepancies between Ref and array indexing. #33322

@justinjfu

Description

The following indexing cases behave like NumPy on a plain `jax.Array`, but fail (or return the wrong shape) when the same array is wrapped in a Ref created with `jax.new_ref`.

class RefIndexingTest(jtu.JaxTestCase):
  def new_ref(self, x):
    # These tests all pass if we return x directly.
    return jax.new_ref(x)

  def test_negative_slice(self):
    x_np = np.zeros((10,))
    x = self.new_ref(jnp.asarray(x_np))
    indexer = slice(-2, -12, -1)
    self.assertEqual(x[indexer].shape, x_np[indexer].shape)
    self.assertArraysEqual(x[indexer], x_np[indexer])

  def test_slice_oob_clamping(self):
    x_np = np.zeros((10,))
    x = self.new_ref(jnp.asarray(x_np))
    indexer = slice(11, 12, 1)
    self.assertEqual(x[indexer].shape, x_np[indexer].shape)
    self.assertArraysEqual(x[indexer], x_np[indexer])

  def test_none_index(self):
    x_np = np.zeros((), dtype=np.int32)
    x = self.new_ref(jnp.asarray(x_np))
    indexer = None
    self.assertEqual(x[indexer].shape, x_np[indexer].shape)
    self.assertArraysEqual(x[indexer], x_np[indexer])

  def test_insert_none_axis(self):
    x_np = np.zeros(())
    x = self.new_ref(jnp.asarray(x_np))
    indexer = (..., None)
    self.assertEqual(x[indexer].shape, x_np[indexer].shape)
    self.assertArraysEqual(x[indexer], x_np[indexer])

Result:

======================================================================
ERROR: test_slice_oob_clamping (__main__.RefIndexingTest)
RefIndexingTest.test_slice_oob_clamping
----------------------------------------------------------------------

ValueError: Out of bound slice: start=10, dim=10.

======================================================================
ERROR: test_insert_none_axis (__main__.RefIndexingTest)
RefIndexingTest.test_insert_none_axis
----------------------------------------------------------------------

ValueError: `indices` must not be longer than `shape`: indices=(None,), shape=()

======================================================================
ERROR: test_negative_slice (__main__.RefIndexingTest)
RefIndexingTest.test_negative_slice
----------------------------------------------------------------------

ValueError: slice must have a step >= 1 (found: -1)

======================================================================
FAIL: test_none_index (__main__.RefIndexingTest)
RefIndexingTest.test_none_index
----------------------------------------------------------------------
AssertionError: Tuples differ: () != (1,)

Second tuple contains 1 additional elements.
First extra element 0:
1

- ()
+ (1,)

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.8.1
jaxlib: 0.8.1
numpy: 2.3.4
python: 3.12.11
device info: cpu-1, 1 local devices
process_count: 1
platform: Linux

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions