Import issue depending on JAX version #3

@HatPdotS

Description

I am using the package to implement a fast rotation function for X-ray crystallography. Importing s2ball.transform triggers a ModuleNotFoundError for jax.config:

     Traceback (most recent call last):                                                                                                              
       File "<stdin>", line 1, in <module>                                                                                                           
       File "/das/work/units/LBR-FEL/p17490/Peter/Library/s2ball/s2ball/s2ball/transform/__init__.py", line 1, in <module>                           
         from . import harmonic                                                                                                                      
       File "/das/work/units/LBR-FEL/p17490/Peter/Library/s2ball/s2ball/s2ball/transform/harmonic.py", line 3, in <module>                           
         from s2ball.construct import matrix                                                                                                         
       File "/das/work/units/LBR-FEL/p17490/Peter/Library/s2ball/s2ball/s2ball/construct/__init__.py", line 3, in <module>                           
         from . import wavelet_constructor                                                                                                           
       File "/das/work/units/LBR-FEL/p17490/Peter/Library/s2ball/s2ball/s2ball/construct/wavelet_constructor.py", line 5, in <module>                
         from s2ball.wavelets.helper_functions import *                                                                                              
       File "/das/work/units/LBR-FEL/p17490/Peter/Library/s2ball/s2ball/s2ball/wavelets/__init__.py", line 2, in <module>                            
         from . import tiling                                                                                                                        
       File "/das/work/units/LBR-FEL/p17490/Peter/Library/s2ball/s2ball/s2ball/wavelets/tiling.py", line 7, in <module>                              
         from jax.config import config                                                                                                               
     ModuleNotFoundError: No module named 'jax.config'
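To check whether a given environment is affected without going through s2ball, here is a minimal sketch that attempts the same legacy import used in tiling.py and reports the installed JAX version. The variable name legacy_import_ok is only illustrative.

```python
import jax

# Try the legacy import path used in s2ball/wavelets/tiling.py.
try:
    from jax.config import config  # noqa: F401
    legacy_import_ok = True
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    legacy_import_ok = False

print(f"JAX {jax.__version__}: legacy 'from jax.config import config' works: {legacy_import_ok}")
```

If this prints False, importing s2ball.transform will fail the same way as in the traceback above.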

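This error depends on the JAX version. Newer JAX releases no longer provide the jax.config submodule, so "from jax.config import config" fails, while the config object is still available as the attribute jax.config. The proper fix would be to state the compatible JAX versions explicitly. As a temporary band-aid, the import can fall back to the attribute: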

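    try:
        # Older JAX releases: the config object lives in the jax.config submodule
        from jax.config import config
        config.update("jax_enable_x64", True)
    except ImportError:
        # Newer JAX releases removed the submodule; use the jax.config attribute instead
        import jax
        jax.config.update("jax_enable_x64", True)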
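If every JAX version that s2ball needs to support exposes the config object as the attribute jax.config (the except branch above already relies on this), the try/except could be dropped entirely. A minimal sketch, assuming that holds for the supported versions:

```python
import jax
import jax.numpy as jnp

# Access the config object as an attribute of the top-level jax package.
# This avoids the legacy "from jax.config import config" path, which raises
# ModuleNotFoundError on JAX releases that removed the jax.config submodule.
jax.config.update("jax_enable_x64", True)

# With 64-bit mode enabled, floating-point arrays default to float64.
x = jnp.zeros(3)
print(x.dtype)  # float64
```

Because 64-bit mode should be enabled before any arrays are created, this call belongs near the top of the importing module.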

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests