Skip to content

Commit 1ad4095

Browse files
committed
add automatic backend detection and selection
1 parent 47d2766 commit 1ad4095

File tree

1 file changed

+76
-20
lines changed

1 file changed

+76
-20
lines changed

bayesflow/__init__.py

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,62 @@
1-
from . import (
2-
approximators,
3-
adapters,
4-
augmentations,
5-
datasets,
6-
diagnostics,
7-
distributions,
8-
experimental,
9-
networks,
10-
simulators,
11-
utils,
12-
workflows,
13-
wrappers,
14-
)
15-
16-
from .adapters import Adapter
17-
from .approximators import ContinuousApproximator, PointApproximator
18-
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
19-
from .simulators import make_simulator
20-
from .workflows import BasicWorkflow
1+
# ruff: noqa: E402
2+
# disable E402 to allow for setup code before importing any internals (which could import keras)
213

224

235
def setup():
246
# perform any necessary setup without polluting the namespace
7+
import os
8+
from importlib.util import find_spec
9+
10+
issue_url = "https://github.com/bayesflow-org/bayesflow/issues/new?template=bug_report.md"
11+
12+
if "KERAS_BACKEND" not in os.environ:
13+
# check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
14+
class Backend:
15+
def __init__(self, display_name, package_name, env_name, install_url):
16+
self.display_name = display_name
17+
self.package_name = package_name
18+
self.env_name = env_name
19+
self.install_url = install_url
20+
21+
backends = [
22+
Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation"),
23+
Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/"),
24+
Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install"),
25+
]
26+
27+
found_backends = []
28+
for backend in backends:
29+
if find_spec(backend.package_name) is not None:
30+
found_backends.append(backend)
31+
32+
if not found_backends:
33+
message = "No suitable backend found. Please install one of the following:\n"
34+
for backend in backends:
35+
message += f"{backend.display_name}\n"
36+
message += "\n"
37+
38+
message += f"If you continue to see this error, please file an bug report at {issue_url}.\n"
39+
message += (
40+
"You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
41+
)
42+
message += "https://keras.io/getting_started/#configuring-your-backend"
43+
44+
raise ImportError(message)
45+
elif len(found_backends) > 1:
46+
message = "Multiple backends found:\n"
47+
for backend in found_backends:
48+
message += f"- {backend.display_name}\n"
49+
message += "\n"
50+
51+
message += (
52+
"You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
53+
)
54+
message += "https://keras.io/getting_started/#configuring-your-backend"
55+
56+
raise ImportError(message)
57+
else:
58+
os.environ["KERAS_BACKEND"] = found_backends[0].env_name
59+
2560
import keras
2661
import logging
2762

@@ -60,3 +95,24 @@ def setup():
6095
# call and clean up namespace
6196
setup()
6297
del setup
98+
99+
from . import (
100+
approximators,
101+
adapters,
102+
augmentations,
103+
datasets,
104+
diagnostics,
105+
distributions,
106+
experimental,
107+
networks,
108+
simulators,
109+
utils,
110+
workflows,
111+
wrappers,
112+
)
113+
114+
from .adapters import Adapter
115+
from .approximators import ContinuousApproximator, PointApproximator
116+
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
117+
from .simulators import make_simulator
118+
from .workflows import BasicWorkflow

0 commit comments

Comments
 (0)