Closed
Description
System info:
Ubuntu 20.04
Python 3.11
NVIDIA RTX 3080 Ti
JAX versions:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35
I get the following error while migrating from jax.random.PRNGKey to jax.random.key.
Here is the full traceback:
Traceback (most recent call last):
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 714, in <module>
app.run(main)
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 682, in main
score = score_submission_on_workload(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 587, in score_submission_on_workload
timing, metrics = train_once(workload, workload_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 221, in train_once
input_queue = workload._build_input_queue(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 155, in _build_input_queue
ds = _build_mnist_dataset(
^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
~~~~~~~~^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 646, in _getitem
return lax_numpy._rewriting_take(self, item)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11411, in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11420, in _gather
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11528, in _index_to_gather
idx = _canonicalize_tuple_index(len(x_shape), idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11852, in _canonicalize_tuple_index
raise IndexError(
IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.
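For context, here is a minimal sketch of what appears to be happening, assuming the default threefry PRNG implementation. jax.random.PRNGKey returns a raw uint32 array of shape (2,), so data_rng[0] is a valid index. jax.random.key instead returns a typed key array with shape (), so the same indexing fails with the IndexError above. Unwrapping the key with jax.random.key_data before indexing is one possible workaround. It is not necessarily the fix the maintainers adopted, and the variable names (legacy_key, typed_key, seed) are illustrative only.

```python
import jax

# Legacy raw key: a uint32 array of shape (2,), so integer indexing works.
legacy_key = jax.random.PRNGKey(0)
print(legacy_key.shape, legacy_key.dtype)

# New-style typed key: a 0-dimensional array with an extended key dtype.
typed_key = jax.random.key(0)
print(typed_key.shape)

# Indexing the 0-d typed key reproduces the IndexError from the traceback.
try:
    typed_key[0]
except IndexError as e:
    print("IndexError:", e)

# Possible workaround (assumption): unwrap the raw uint32 key data first.
# With the default threefry implementation this has shape (2,); other PRNG
# implementations (e.g. rbg) use a different trailing shape.
raw = jax.random.key_data(typed_key)
seed = int(raw[0])  # plain Python int, e.g. for a tf.data shuffle seed
print(seed)
```

Because key_data exposes the same uint32 words that PRNGKey used to return, this keeps the derived shuffle seed unchanged across the migration. Another option is to draw a seed from the key (for example with jax.random.randint), which avoids relying on the key's internal layout but produces different seed values than before.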