Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bayesflow/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

from ..utils._docs import _add_imports_to_all

_add_imports_to_all(include_modules=["diffusion_model"])
_add_imports_to_all()
18 changes: 11 additions & 7 deletions bayesflow/experimental/diffusion_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from .diffusion_model import DiffusionModel
from bayesflow.experimental.diffusion_model.schedules import CosineNoiseSchedule
from bayesflow.experimental.diffusion_model.schedules import EDMNoiseSchedule
from bayesflow.experimental.diffusion_model.schedules import NoiseSchedule
from .dispatch import find_noise_schedule
from bayesflow.networks import DiffusionModel as StabilizedDiffusionModel

from ...utils._docs import _add_imports_to_all

_add_imports_to_all(include_modules=[])
def DiffusionModel(*args, **kwargs):
from warnings import warn

Check warning on line 5 in bayesflow/experimental/diffusion_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/experimental/diffusion_model/__init__.py#L5

Added line #L5 was not covered by tests

warn(

Check warning on line 7 in bayesflow/experimental/diffusion_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/experimental/diffusion_model/__init__.py#L7

Added line #L7 was not covered by tests
"DiffusionModel has been stabilized and moved to bayesflow.networks. "
"Please switch your imports to the new location. This reference will be "
"removed in a future version.",
FutureWarning,
)
return StabilizedDiffusionModel(*args, **kwargs)

Check warning on line 13 in bayesflow/experimental/diffusion_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/experimental/diffusion_model/__init__.py#L13

Added line #L13 was not covered by tests
3 changes: 2 additions & 1 deletion bayesflow/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .consistency_models import ConsistencyModel
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .diffusion_model import DiffusionModel
from .flow_matching import FlowMatching
from .inference_network import InferenceNetwork
from .point_inference_network import PointInferenceNetwork
Expand All @@ -19,4 +20,4 @@

from ..utils._docs import _add_imports_to_all

_add_imports_to_all(include_modules=[])
_add_imports_to_all(include_modules=["diffusion_model"])
9 changes: 9 additions & 0 deletions bayesflow/networks/diffusion_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .diffusion_model import DiffusionModel
from .schedules import CosineNoiseSchedule
from .schedules import EDMNoiseSchedule
from .schedules import NoiseSchedule
from .dispatch import find_noise_schedule

from ...utils._docs import _add_imports_to_all

_add_imports_to_all(include_modules=[])
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import keras
from keras import ops

from bayesflow.networks import InferenceNetwork
from ..inference_network import InferenceNetwork
from bayesflow.types import Tensor, Shape
from bayesflow.utils import (
expand_right_as,
Expand Down
2 changes: 1 addition & 1 deletion examples/Likelihood_Estimation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@
"source": [
"workflow = bf.BasicWorkflow(\n",
" simulator=simulator,\n",
" inference_network=bf.experimental.DiffusionModel(),\n",
" inference_network=bf.networks.DiffusionModel(),\n",
" inference_variables=\"x\",\n",
" inference_conditions=\"theta\",\n",
" initial_learning_rate=1e-3,\n",
Expand Down
12 changes: 6 additions & 6 deletions tests/test_networks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

@pytest.fixture()
def diffusion_model_edm_F():
from bayesflow.experimental import DiffusionModel
from bayesflow.networks import DiffusionModel

return DiffusionModel(
subnet=MLP([8, 8]),
Expand All @@ -17,7 +17,7 @@ def diffusion_model_edm_F():

@pytest.fixture()
def diffusion_model_edm_velocity():
from bayesflow.experimental import DiffusionModel
from bayesflow.networks import DiffusionModel

return DiffusionModel(
subnet=MLP([8, 8]),
Expand All @@ -29,7 +29,7 @@ def diffusion_model_edm_velocity():

@pytest.fixture()
def diffusion_model_edm_noise():
from bayesflow.experimental import DiffusionModel
from bayesflow.networks import DiffusionModel

return DiffusionModel(
subnet=MLP([8, 8]),
Expand All @@ -41,7 +41,7 @@ def diffusion_model_edm_noise():

@pytest.fixture()
def diffusion_model_cosine_F():
from bayesflow.experimental import DiffusionModel
from bayesflow.networks import DiffusionModel

return DiffusionModel(
subnet=MLP([8, 8]),
Expand All @@ -53,7 +53,7 @@ def diffusion_model_cosine_F():

@pytest.fixture()
def diffusion_model_cosine_velocity():
from bayesflow.experimental import DiffusionModel
from bayesflow.networks import DiffusionModel

return DiffusionModel(
subnet=MLP([8, 8]),
Expand All @@ -65,7 +65,7 @@ def diffusion_model_cosine_velocity():

@pytest.fixture()
def diffusion_model_cosine_noise():
from bayesflow.experimental import DiffusionModel
from bayesflow.networks import DiffusionModel

return DiffusionModel(
subnet=MLP([8, 8]),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_networks/test_diffusion_model/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

@pytest.fixture()
def cosine_noise_schedule():
from bayesflow.experimental.diffusion_model.schedules import CosineNoiseSchedule
from bayesflow.networks.diffusion_model.schedules import CosineNoiseSchedule

return CosineNoiseSchedule(min_log_snr=-12, max_log_snr=12, shift=0.1, weighting="likelihood_weighting")


@pytest.fixture()
def edm_noise_schedule():
from bayesflow.experimental.diffusion_model.schedules import EDMNoiseSchedule
from bayesflow.networks.diffusion_model.schedules import EDMNoiseSchedule

return EDMNoiseSchedule(sigma_data=10.0, sigma_min=1e-5, sigma_max=85.0)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_networks/test_inference_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_cycle_consistency(generative_inference_network, random_samples, random_
# cycle-consistency means the forward and inverse methods are inverses of each other
import bayesflow as bf

if isinstance(generative_inference_network, bf.experimental.DiffusionModel):
if isinstance(generative_inference_network, bf.networks.DiffusionModel):
pytest.skip(reason="test unstable for untrained diffusion models")
try:
forward_output, forward_log_density = generative_inference_network(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from bayesflow.utils import find_inference_network, find_distribution, find_network, find_summary_network
from bayesflow.experimental.diffusion_model import find_noise_schedule
from bayesflow.networks.diffusion_model import find_noise_schedule

# --- Tests for find__network.py ---

Expand Down Expand Up @@ -247,7 +247,7 @@ def test_find_summary_network_invalid_type():


def test_find_noise_schedule_by_name():
from bayesflow.experimental.diffusion_model.schedules import CosineNoiseSchedule, EDMNoiseSchedule
from bayesflow.networks.diffusion_model.schedules import CosineNoiseSchedule, EDMNoiseSchedule

schedule = find_noise_schedule("cosine")
assert isinstance(schedule, CosineNoiseSchedule)
Expand All @@ -262,7 +262,7 @@ def test_find_noise_schedule_unknown_name():


def test_pass_noise_schedule():
from bayesflow.experimental.diffusion_model.schedules.noise_schedule import NoiseSchedule
from bayesflow.networks.diffusion_model.schedules.noise_schedule import NoiseSchedule

class CustomNoiseSchedule(NoiseSchedule):
def __init__(self):
Expand All @@ -282,7 +282,7 @@ def derivative_log_snr(self, log_snr_t, training):


def test_pass_noise_schedule_type():
from bayesflow.experimental.diffusion_model.schedules import EDMNoiseSchedule
from bayesflow.networks.diffusion_model.schedules import EDMNoiseSchedule

schedule = find_noise_schedule(EDMNoiseSchedule, sigma_data=10.0)
assert isinstance(schedule, EDMNoiseSchedule)
Expand Down