
IndexError: Too many indices: 0-dimensional array indexed with 1 regular index, while migrating from jax.random.PRNGKey to jax.random.key #815

@init-22

System info:
Ubuntu 20.04
Python 3.11
NVIDIA RTX 3080 Ti

JAX versions:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35

I get the following error after migrating from jax.random.PRNGKey to jax.random.key.
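A minimal sketch of why indexing that worked with the legacy key fails with the new one (an illustration written for this report, not code from the repository): jax.random.PRNGKey returns a raw uint32 array of shape (2,), so key[0] picks one word, while jax.random.key returns a scalar (shape ()) array with an extended key dtype, which has no axis to index.

```python
import jax

# Legacy API: a raw uint32 array, shape (2,), so legacy_key[0] is a valid index.
legacy_key = jax.random.PRNGKey(0)
print(legacy_key.shape, legacy_key.dtype)

# New-style API: a 0-dimensional array whose dtype is an extended PRNG key dtype.
typed_key = jax.random.key(0)
print(typed_key.shape, typed_key.dtype)

# Indexing the scalar key array raises the same kind of IndexError as in the traceback.
try:
    typed_key[0]
    raised = False
except IndexError as err:
    raised = True
    print("IndexError:", err)
```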

Here is the full traceback:

Traceback (most recent call last):
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 714, in <module>
    app.run(main)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 682, in main
    score = score_submission_on_workload(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 221, in train_once
    input_queue = workload._build_input_queue(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 155, in _build_input_queue
    ds = _build_mnist_dataset(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
    ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
                                                 ~~~~~~~~^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 646, in _getitem
    return lax_numpy._rewriting_take(self, item)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11411, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11420, in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11528, in _index_to_gather
    idx = _canonicalize_tuple_index(len(x_shape), idx)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11852, in _canonicalize_tuple_index
    raise IndexError(
IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.
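One possible workaround, sketched under the assumption that the tf.data shuffle only needs a plain integer seed: unwrap a typed key with jax.random.key_data before indexing, which recovers the same uint32 words the legacy key held. The helper name shuffle_seed_from_key is hypothetical and not part of JAX or the repository; it accepts either key style so the workload keeps working during the migration.

```python
import jax


def shuffle_seed_from_key(rng) -> int:
    """Hypothetical helper: derive an int seed from a legacy or a typed PRNG key."""
    if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key):
        # Typed key (from jax.random.key): unwrap the underlying uint32 words.
        rng = jax.random.key_data(rng)
    # rng is now a raw uint32 array; with the default threefry implementation it
    # has shape (2,), so taking element 0 matches the old data_rng[0] behavior.
    return int(rng[0])


# Split keys of both styles produce identical underlying bits, hence identical seeds.
typed_rng = jax.random.split(jax.random.key(7))[0]
legacy_rng = jax.random.split(jax.random.PRNGKey(7))[0]
print(shuffle_seed_from_key(typed_rng), shuffle_seed_from_key(legacy_rng))
```

With a helper like this, the failing line in _build_mnist_dataset could read ds.shuffle(16 * global_batch_size, seed=shuffle_seed_from_key(data_rng)); the shorter seed=jax.random.key_data(data_rng)[0] also works if data_rng is always a typed key.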
