Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 76 additions & 20 deletions bayesflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,62 @@
from . import (
approximators,
adapters,
augmentations,
datasets,
diagnostics,
distributions,
experimental,
networks,
simulators,
utils,
workflows,
wrappers,
)

from .adapters import Adapter
from .approximators import ContinuousApproximator, PointApproximator
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
from .workflows import BasicWorkflow
# ruff: noqa: E402
# disable E402 to allow for setup code before importing any internals (which could import keras)


def setup():
# perform any necessary setup without polluting the namespace
import os
from importlib.util import find_spec

issue_url = "https://github.com/bayesflow-org/bayesflow/issues/new?template=bug_report.md"

if "KERAS_BACKEND" not in os.environ:
# check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
class Backend:
def __init__(self, display_name, package_name, env_name, install_url):
self.display_name = display_name
self.package_name = package_name
self.env_name = env_name
self.install_url = install_url

backends = [
Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation"),
Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/"),
Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install"),
]

found_backends = []
for backend in backends:
if find_spec(backend.package_name) is not None:
found_backends.append(backend)

if not found_backends:
message = "No suitable backend found. Please install one of the following:\n"
for backend in backends:
message += f"{backend.display_name}\n"
message += "\n"

message += f"If you continue to see this error, please file an bug report at {issue_url}.\n"
message += (
"You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
)
message += "https://keras.io/getting_started/#configuring-your-backend"

raise ImportError(message)
elif len(found_backends) > 1:
message = "Multiple backends found:\n"
for backend in found_backends:
message += f"- {backend.display_name}\n"
message += "\n"

message += (
"You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
)
message += "https://keras.io/getting_started/#configuring-your-backend"

raise ImportError(message)
else:
os.environ["KERAS_BACKEND"] = found_backends[0].env_name

import keras
import logging

Expand Down Expand Up @@ -60,3 +95,24 @@ def setup():
# call and clean up namespace
setup()
del setup

from . import (
approximators,
adapters,
augmentations,
datasets,
diagnostics,
distributions,
experimental,
networks,
simulators,
utils,
workflows,
wrappers,
)

from .adapters import Adapter
from .approximators import ContinuousApproximator, PointApproximator
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
from .workflows import BasicWorkflow