Skip to content

Commit eb1f6ce

Browse files
committed
move DiffusionModel from experimental to networks
Stabilizes the DiffusionModel class. A deprecation warning for the DiffusionModel class in the experimental module was added.
1 parent ebbddce commit eb1f6ce

File tree

15 files changed

+38
-24
lines changed

15 files changed

+38
-24
lines changed

bayesflow/experimental/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99

1010
from ..utils._docs import _add_imports_to_all
1111

12-
_add_imports_to_all(include_modules=["diffusion_model"])
12+
_add_imports_to_all()
Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
from .diffusion_model import DiffusionModel
2-
from bayesflow.experimental.diffusion_model.schedules import CosineNoiseSchedule
3-
from bayesflow.experimental.diffusion_model.schedules import EDMNoiseSchedule
4-
from bayesflow.experimental.diffusion_model.schedules import NoiseSchedule
5-
from .dispatch import find_noise_schedule
1+
from bayesflow.networks import DiffusionModel as StabilizedDiffusionModel
62

7-
from ...utils._docs import _add_imports_to_all
83

9-
_add_imports_to_all(include_modules=[])
4+
def DiffusionModel(*args, **kwargs):
5+
from warnings import warn
6+
7+
warn(
8+
"DiffusionModel has been stabilized and moved to bayesflow.networks. "
9+
"Please switch your imports to the new location. This reference will be "
10+
"removed in a future version.",
11+
FutureWarning,
12+
)
13+
return StabilizedDiffusionModel(*args, **kwargs)

bayesflow/networks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .consistency_models import ConsistencyModel
88
from .coupling_flow import CouplingFlow
99
from .deep_set import DeepSet
10+
from .diffusion_model import DiffusionModel
1011
from .flow_matching import FlowMatching
1112
from .inference_network import InferenceNetwork
1213
from .point_inference_network import PointInferenceNetwork
@@ -19,4 +20,4 @@
1920

2021
from ..utils._docs import _add_imports_to_all
2122

22-
_add_imports_to_all(include_modules=[])
23+
_add_imports_to_all(include_modules=["diffusion_model"])
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .diffusion_model import DiffusionModel
2+
from .schedules import CosineNoiseSchedule
3+
from .schedules import EDMNoiseSchedule
4+
from .schedules import NoiseSchedule
5+
from .dispatch import find_noise_schedule
6+
7+
from ...utils._docs import _add_imports_to_all
8+
9+
_add_imports_to_all(include_modules=[])

bayesflow/experimental/diffusion_model/diffusion_model.py renamed to bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import keras
55
from keras import ops
66

7-
from bayesflow.networks import InferenceNetwork
7+
from ..inference_network import InferenceNetwork
88
from bayesflow.types import Tensor, Shape
99
from bayesflow.utils import (
1010
expand_right_as,

0 commit comments

Comments
 (0)