
Upgrade JAX to 0.4.31 to access jax.lax.map(batch_size=X) #188

@georgematheos

Description

@georgematheos

@eightysteele

On the gen3d branch, we are running into issues upgrading JAX to 0.4.31. We want to upgrade because, per the top answer here, starting with JAX 0.4.31 the function jax.lax.map supports a batch_size argument, and we would like to pass this batch_size parameter at this line. One way to test that the upgrade works is to confirm that all the tests in tests/gen3d/ pass after adding batch_size=1000 at the indicated line. More specifically, if this line runs without failing, we should be good to go. [In this blob it looks like we commented this test out, oops! It should be uncommented, and it should pass.]

When we change this line to request jaxlib ==0.4.31 and run pixi install, we get the following error:

(gpu) georgematheos@pixi-vm-2:~/b3d$ pixi install
 WARN Defined custom mapping channel https://conda.anaconda.org/conda-forge/ is missing from project channels
  × failed to solve the conda requirements of 'gpu' 'linux-64'
  ╰─▶ Cannot solve the request because of: The following packages are incompatible
      ├─ pytorch ==2.3.0 cuda12* can be installed with any of the following options:
      │  └─ pytorch 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 would require
      │     └─ cudnn >=8.9.7.29,<9.0a0, which can be installed with any of the following options:
      │        └─ cudnn 8.9.7.29
      └─ jaxlib ==0.4.31 cuda12* cannot be installed because there are no viable options:
         └─ jaxlib 0.4.31 | 0.4.31 | 0.4.31 would require
            └─ cudnn >=9.2.1.18,<10.0a0, which cannot be installed because there are no viable options:
               └─ cudnn 9.2.1.18 | 9.2.1.18, which conflicts with the versions reported above.
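The core of the conflict is that the two cudnn ranges in the log above do not intersect. The sketch below checks this with a hypothetical `v` helper that compares versions as integer tuples. It is a simplification, not conda's full version-ordering rules: an upper bound like `<9.0a0` is treated as `< (9,)`, which excludes every 9.x release, including pre-releases.

```python
def v(s):
    """Hypothetical helper: '8.9.7.29' -> (8, 9, 7, 29) for tuple comparison."""
    return tuple(int(p) for p in s.split("."))


def ranges_overlap(lo1, hi1, lo2, hi2):
    """Half-open ranges [lo, hi) overlap iff each starts before the other ends."""
    return lo1 < hi2 and lo2 < hi1


# pytorch 2.3.0 (cuda12) needs cudnn >=8.9.7.29,<9.0a0
torch_lo, torch_hi = v("8.9.7.29"), (9,)
# jaxlib 0.4.31 (cuda12) needs cudnn >=9.2.1.18,<10.0a0
jax_lo, jax_hi = v("9.2.1.18"), (10,)

overlap = ranges_overlap(torch_lo, torch_hi, jax_lo, jax_hi)
```

Here `overlap` is `False`: pytorch 2.3.0 caps cudnn below 9, while jaxlib 0.4.31 needs at least 9.2.1.18, so no single cudnn build can satisfy both.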

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
