Open
Labels
bug (Something isn't working)
Description
I expected the three print statements below to give the same output, with JAX handling the index dtype conversions.
import jax.numpy as jnp
numbers = jnp.arange(10000*10).reshape(10000, 10)
print(numbers[1000, jnp.array(0, dtype=jnp.int32)])
print(numbers[jnp.array(1000, dtype=jnp.int32), jnp.array(0, dtype=jnp.int8)])
print(numbers[1000, jnp.array(0, dtype=jnp.int8)])
However, the actual output is:
10000
10000
0
Is it expected that this fails silently? Also, I am not sure the problem requires a Python scalar index: it seemed to cause issues in a larger piece of code where the first index is also a JAX array, but I have not yet been able to reduce that to a small reproducer.
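A possible workaround, shown only as a sketch: cast every index to a wider integer dtype before indexing. This mirrors the first print statement in the reproducer, where an int32 index gives the expected 10000. The helper name `widen_index` is hypothetical and not part of JAX, and this does not explain why the int8 case behaves differently.

```python
import jax.numpy as jnp


def widen_index(idx, dtype=jnp.int32):
    """Hypothetical helper: cast an index to a wider integer dtype.

    Assumption: the mismatch only shows up when a narrow index dtype
    (such as int8) is mixed with an index that does not fit in it.
    """
    return jnp.asarray(idx, dtype=dtype)


numbers = jnp.arange(10000 * 10).reshape(10000, 10)

# Same int8 column index as in the reproducer above.
col = jnp.array(0, dtype=jnp.int8)

# Widen the narrow index explicitly instead of relying on implicit conversion.
value = numbers[1000, widen_index(col)]
print(value)
```

With the index widened to int32, the lookup has the same form as the first print statement of the reproducer, so it should agree with that line rather than with the int8 result.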
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.6.0
jaxlib: 0.6.0
numpy: 2.3.0
python: 3.13.3 (main, Apr 9 2025, 04:03:52) [Clang 20.1.0 ]
device info: NVIDIA GeForce RTX 4050 Laptop GPU-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='bendodge', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')