Skip to content

Commit c1d5c86

Browse files
committed
Add priority ordering of backends
1 parent 7bc1e7d commit c1d5c86

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

bayesflow/__init__.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@ def setup():
1212
if "KERAS_BACKEND" not in os.environ:
1313
# check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
1414
class Backend:
15-
def __init__(self, display_name, package_name, env_name, install_url):
15+
def __init__(self, display_name, package_name, env_name, install_url, priority):
1616
self.display_name = display_name
1717
self.package_name = package_name
1818
self.env_name = env_name
1919
self.install_url = install_url
20+
self.priority = priority
2021

2122
backends = [
22-
Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation"),
23-
Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/"),
24-
Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install"),
23+
Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation", 0),
24+
Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/", 1),
25+
Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install", 2),
2526
]
2627

2728
found_backends = []
@@ -42,18 +43,19 @@ def __init__(self, display_name, package_name, env_name, install_url):
4243
message += "https://keras.io/getting_started/#configuring-your-backend"
4344

4445
raise ImportError(message)
45-
elif len(found_backends) > 1:
46-
message = "Multiple backends found:\n"
47-
for backend in found_backends:
48-
message += f"- {backend.display_name}\n"
49-
message += "\n"
5046

51-
message += (
52-
"You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
53-
)
54-
message += "https://keras.io/getting_started/#configuring-your-backend"
47+
if len(found_backends) > 1:
48+
import warnings
5549

56-
raise ImportError(message)
50+
found_backends.sort(key=lambda b: b.priority)
51+
chosen_backend = found_backends[0]
52+
53+
warnings.warn(
54+
f"Multiple Keras-compatible backends detected ({', '.join(b.display_name for b in found_backends)}).\n"
55+
f"Defaulting to {chosen_backend.display_name}.\n"
56+
"To override, set the KERAS_BACKEND environment variable before importing bayesflow.\n"
57+
"See: https://keras.io/getting_started/#configuring-your-backend"
58+
)
5759
else:
5860
os.environ["KERAS_BACKEND"] = found_backends[0].env_name
5961

0 commit comments

Comments
 (0)