diff --git a/bayesflow/experimental/__init__.py b/bayesflow/experimental/__init__.py index adfba709e..8364aab92 100644 --- a/bayesflow/experimental/__init__.py +++ b/bayesflow/experimental/__init__.py @@ -1,5 +1,6 @@ from .cif import CIF from .continuous_time_consistency_model import ContinuousTimeConsistencyModel +from .free_form_flow import FreeFormFlow from ..utils._docs import _add_imports_to_all diff --git a/bayesflow/networks/free_form_flow/__init__.py b/bayesflow/experimental/free_form_flow/__init__.py similarity index 100% rename from bayesflow/networks/free_form_flow/__init__.py rename to bayesflow/experimental/free_form_flow/__init__.py diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py similarity index 99% rename from bayesflow/networks/free_form_flow/free_form_flow.py rename to bayesflow/experimental/free_form_flow/free_form_flow.py index aa21dd63a..36382362e 100644 --- a/bayesflow/networks/free_form_flow/free_form_flow.py +++ b/bayesflow/experimental/free_form_flow/free_form_flow.py @@ -14,7 +14,7 @@ deserialize_value_or_type, ) -from ..inference_network import InferenceNetwork +from bayesflow.networks import InferenceNetwork @serializable(package="networks.free_form_flow") diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py index d11d16d9c..d29576819 100644 --- a/bayesflow/networks/__init__.py +++ b/bayesflow/networks/__init__.py @@ -2,7 +2,6 @@ from .coupling_flow import CouplingFlow from .deep_set import DeepSet from .flow_matching import FlowMatching -from .free_form_flow import FreeFormFlow from .inference_network import InferenceNetwork from .mlp import MLP from .lstnet import LSTNet diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 2310fbe33..7910574ab 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -50,14 +50,14 @@ def coupling_flow_subnet(subnet): @pytest.fixture() def free_form_flow(): - from bayesflow.networks import FreeFormFlow + from bayesflow.experimental import FreeFormFlow return FreeFormFlow() @pytest.fixture() def free_form_flow_subnet(subnet): - from bayesflow.networks import FreeFormFlow + from bayesflow.experimental import FreeFormFlow return FreeFormFlow(encoder_subnet=subnet, decoder_subnet=subnet)