@@ -12,16 +12,17 @@ def setup():
1212 if "KERAS_BACKEND" not in os .environ :
1313 # check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
1414 class Backend :
15- def __init__ (self , display_name , package_name , env_name , install_url ):
15+ def __init__ (self , display_name , package_name , env_name , install_url , priority ):
1616 self .display_name = display_name
1717 self .package_name = package_name
1818 self .env_name = env_name
1919 self .install_url = install_url
20+ self .priority = priority
2021
2122 backends = [
22- Backend ("JAX" , "jax" , "jax" , "https://docs.jax.dev/en/latest/quickstart.html#installation" ),
23- Backend ("PyTorch" , "torch" , "torch" , "https://pytorch.org/get-started/locally/" ),
24- Backend ("TensorFlow" , "tensorflow" , "tensorflow" , "https://www.tensorflow.org/install" ),
23+ Backend ("JAX" , "jax" , "jax" , "https://docs.jax.dev/en/latest/quickstart.html#installation" , 0 ),
24+ Backend ("PyTorch" , "torch" , "torch" , "https://pytorch.org/get-started/locally/" , 1 ),
25+ Backend ("TensorFlow" , "tensorflow" , "tensorflow" , "https://www.tensorflow.org/install" , 2 ),
2526 ]
2627
2728 found_backends = []
@@ -42,18 +43,19 @@ def __init__(self, display_name, package_name, env_name, install_url):
4243 message += "https://keras.io/getting_started/#configuring-your-backend"
4344
4445 raise ImportError (message )
45- elif len (found_backends ) > 1 :
46- message = "Multiple backends found:\n "
47- for backend in found_backends :
48- message += f"- { backend .display_name } \n "
49- message += "\n "
5046
51- message += (
52- "You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n "
53- )
54- message += "https://keras.io/getting_started/#configuring-your-backend"
47+ if len (found_backends ) > 1 :
48+ import warnings
5549
56- raise ImportError (message )
50+ found_backends .sort (key = lambda b : b .priority )
51+ chosen_backend = found_backends [0 ]
52+
53+ warnings .warn (
54+ f"Multiple Keras-compatible backends detected ({ ', ' .join (b .display_name for b in found_backends )} ).\n "
55+ f"Defaulting to { chosen_backend .display_name } .\n "
56+ "To override, set the KERAS_BACKEND environment variable before importing bayesflow.\n "
57+ "See: https://keras.io/getting_started/#configuring-your-backend"
58+ )
5759 else :
5860 os .environ ["KERAS_BACKEND" ] = found_backends [0 ].env_name
5961
0 commit comments