|
24 | 24 | } |
25 | 25 | ## VIASH END |
26 | 26 |
|
| 27 | +# --- Fix for CuDNN version mismatch ------------------------------------------ |
| 28 | +# JAX wheels (e.g. jax[cuda12]) come with their own CUDA/cuDNN libraries. |
| 29 | +# However, some systems set LD_LIBRARY_PATH to point to an older system-wide |
| 30 | +# CUDA/cuDNN (e.g. 9.1). That can override JAX's bundled libraries and cause |
| 31 | +# errors like: |
| 32 | +# "Loaded runtime CuDNN library: 9.1.0 but source was compiled with: 9.8.0" |
| 33 | +# |
| 34 | +# To prevent that, we remove LD_LIBRARY_PATH before importing JAX so it uses |
| 35 | +# its own compatible, built-in CUDA/cuDNN stack. |
| 36 | +# ----------------------------------------------------------------------------- |
| 37 | +import os |
| 38 | +print("LD_LIBRARY_PATH before unset:", os.environ.get("LD_LIBRARY_PATH"), flush=True) |
| 39 | +os.environ.pop("LD_LIBRARY_PATH", None) |
| 40 | + |
| 41 | +# gpu check |
| 42 | +import jax |
| 43 | +import jaxlib |
| 44 | +print("GPU check", flush=True) |
| 45 | +print("jax:", jax.__version__, flush=True) |
| 46 | +print("jaxlib:", jaxlib.__version__, flush=True) |
| 47 | +print("backend:", jax.default_backend(), flush=True) |
| 48 | +print("devices:", jax.devices(), flush=True) |
| 49 | +print("LD_LIBRARY_PATH:", os.environ.get("LD_LIBRARY_PATH"), flush=True) |
| 50 | + |
| 51 | + |
27 | 52 | # Optional parameter check: For this specific annotation method the par['input_spatial_normalized_counts'] and par['input_scrnaseq_reference'] are required |
28 | 53 | assert par['input_spatial_normalized_counts'] is not None, 'Spatial input is required for this annotation method.' |
29 | 54 | assert par['input_scrnaseq_reference'] is not None, 'Single cell input is required for this annotation method.' |
|
0 commit comments