
Bug Report: Incompatibility with JAX 0.7.0+ due to deprecated API usage #10

@rafaelkaufmann

Description

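PGMax is currently incompatible with JAX 0.7.0 and later because it calls jax.lib.xla_bridge.get_backend, a deprecated API. The JAX changelog marks this API as deprecated in JAX 0.7.0 and slated for removal in JAX 0.8.0, but on 0.7.0 accessing it already raises an AttributeError instead of only emitting a warning, so PGMax fails at runtime.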

Error

AttributeError: jax.lib.xla_bridge.get_backend is deprecated and will be removed in JAX v0.8.0; use jax.extend.backend.get_backend.

The full traceback shows that the error originates from:

File "/usr/local/lib/python3.12/site-packages/pgmax/infer/inferer.py", line 66, in __post_init__
    if jax.lib.xla_bridge.get_backend().platform == "tpu":  # pragma: no cover

Location

File: pgmax/infer/inferer.py
Line: 66
Method: InfererContext.__post_init__()

Current Code

def __post_init__(self):
    if jax.lib.xla_bridge.get_backend().platform == "tpu":  # pragma: no cover
      warnings.warn(
          "PGMax is not optimized for the TPU backend. Please consider using"
          " GPUs!"
      )

Proposed Fix

Replace the deprecated API with the new jax.extend.backend.get_backend():

import warnings

import jax
import jax.extend  # Must be imported explicitly; `import jax` alone does not load it.

def __post_init__(self):
    if jax.extend.backend.get_backend().platform == "tpu":  # pragma: no cover
      warnings.warn(
          "PGMax is not optimized for the TPU backend. Please consider using"
          " GPUs!"
      )
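Alternatively, the TPU check only needs the platform name of the default backend, and JAX's public jax.default_backend() returns exactly that string (e.g. "cpu", "gpu", or "tpu") without going through jax.lib.xla_bridge or jax.extend. A minimal sketch, written as a hypothetical standalone helper warn_if_tpu rather than PGMax's actual InfererContext code:

```python
import warnings

import jax


def warn_if_tpu() -> str:
    """Warns if JAX's default backend is a TPU and returns its platform name."""
    # jax.default_backend() is public API returning the default backend's
    # platform string (e.g. "cpu", "gpu", "tpu"); it needs neither
    # jax.lib.xla_bridge nor an explicit `import jax.extend`.
    platform = jax.default_backend()
    if platform == "tpu":
        warnings.warn(
            "PGMax is not optimized for the TPU backend. Please consider using"
            " GPUs!"
        )
    return platform


if __name__ == "__main__":
    print(f"JAX default backend: {warn_if_tpu()}")
```

Inside InfererContext.__post_init__, this would reduce the condition to jax.default_backend() == "tpu". Either this or jax.extend.backend.get_backend() avoids the deprecated jax.lib.xla_bridge path; which one fits PGMax better is a maintainer decision.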

JAX Changelog Reference

From the JAX v0.7.0 changelog:

jax.lib.xla_bridge.get_backend is deprecated in JAX v0.7.0 and will be removed in JAX v0.8.0; use jax.extend.backend.get_backend

Impact

  • Users cannot use PGMax with JAX 0.7.0 or later without encountering this error
  • This blocks users from receiving JAX security updates and new features
  • Forces dependency pinning to JAX <0.7.0 in downstream projects

Environment

  • PGMax version: 0.6.1 (from main branch)
  • JAX version: 0.7.0+
  • Python version: 3.12

Workaround

Currently, users must pin JAX to versions below 0.7.0:

jax<0.7.0
jaxlib<0.7.0
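To confirm that an environment actually satisfies this pin, a small stdlib-only check can compare the installed JAX version against the first affected release reported here (0.7.0). This is a sketch with hypothetical helper names (release_tuple, jax_version_is_affected, installed_jax_version); it compares only the leading numeric release segment rather than implementing full PEP 440 ordering:

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Per this report, JAX 0.7.0 is the first release where
# jax.lib.xla_bridge.get_backend raises AttributeError.
FIRST_AFFECTED_JAX = (0, 7, 0)


def release_tuple(version: str) -> Tuple[int, ...]:
    """Parses the leading numeric release segment, padded to three parts.

    Pre-release/dev suffixes are ignored, so "0.7.0rc1" maps to (0, 7, 0).
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    return parts + (0,) * (3 - len(parts))


def jax_version_is_affected(version: str) -> bool:
    """True if `version` is at or above the first affected JAX release."""
    return release_tuple(version) >= FIRST_AFFECTED_JAX


def installed_jax_version() -> Optional[str]:
    """Returns the installed jax distribution version, or None if not installed."""
    try:
        return metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_jax_version()
    if version is None:
        print("jax is not installed")
    elif jax_version_is_affected(version):
        print(f"jax {version} is affected; pin jax<0.7.0 until PGMax is fixed")
    else:
        print(f"jax {version} predates the breaking change")
```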

Note: The fix requires an explicit import jax.extend statement, per JAX's deprecation notice: "please note that you must import jax.extend explicitly."
