Skip to content

Commit c11b615

Browse files
committed
minor cleanup of refactory
1 parent 9a8db95 commit c11b615

File tree

3 files changed

+13
-17
lines changed

3 files changed

+13
-17
lines changed

bayesflow/experimental/diffusion_model/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from .diffusion_model import DiffusionModel
2-
from bayesflow.experimental.diffusion_model.schedules.cosine_noise_schedule import CosineNoiseSchedule
2+
from bayesflow.experimental.diffusion_model.schedules import CosineNoiseSchedule
3+
from bayesflow.experimental.diffusion_model.schedules import EDMNoiseSchedule
4+
from bayesflow.experimental.diffusion_model.schedules import NoiseSchedule
35
from .dispatch import find_noise_schedule
46

57
from ...utils._docs import _add_imports_to_all

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(
105105
self._prediction_type = prediction_type
106106
self._loss_type = loss_type
107107

108-
self.schedule_kwargs = schedule_kwargs or {}
108+
schedule_kwargs = schedule_kwargs or {}
109109
self.noise_schedule = find_noise_schedule(noise_schedule, **self.schedule_kwargs)
110110
self.noise_schedule.validate()
111111

@@ -148,7 +148,6 @@ def get_config(self):
148148
"noise_schedule": self.noise_schedule,
149149
"prediction_type": self._prediction_type,
150150
"loss_type": self._loss_type,
151-
"schedule_kwargs": self.schedule_kwargs,
152151
"integrate_kwargs": self.integrate_kwargs,
153152
}
154153
return base_config | serialize(config)
@@ -194,8 +193,9 @@ def convert_prediction_to_x(
194193
return x1 * z + x2 * pred
195194
elif self._prediction_type == "x":
196195
return pred
197-
else:
196+
elif self._prediction_type == "score":
198197
return (z + sigma_t**2 * pred) / alpha_t
198+
raise ValueError(f"Unknown prediction type {self._prediction_type}.")
199199

200200
def velocity(
201201
self,
@@ -320,12 +320,9 @@ def _forward(
320320
training: bool = False,
321321
**kwargs,
322322
) -> Tensor | tuple[Tensor, Tensor]:
323-
integrate_kwargs = {
324-
**self.integrate_kwargs,
325-
"start_time": kwargs.pop("start_time", 0.0),
326-
"stop_time": kwargs.pop("stop_time", 1.0),
327-
**kwargs,
328-
}
323+
integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0}
324+
integrate_kwargs = integrate_kwargs | self.integrate_kwargs
325+
integrate_kwargs = integrate_kwargs | kwargs
329326

330327
if integrate_kwargs["method"] == "euler_maruyama":
331328
raise ValueError("Stochastic methods are not supported for forward integration.")
@@ -373,12 +370,9 @@ def _inverse(
373370
training: bool = False,
374371
**kwargs,
375372
) -> Tensor | tuple[Tensor, Tensor]:
376-
integrate_kwargs = {
377-
**self.integrate_kwargs,
378-
"start_time": kwargs.pop("start_time", 1.0),
379-
"stop_time": kwargs.pop("stop_time", 0.0),
380-
**kwargs,
381-
}
373+
integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0}
374+
integrate_kwargs = integrate_kwargs | self.integrate_kwargs
375+
integrate_kwargs = integrate_kwargs | kwargs
382376
if density:
383377
if integrate_kwargs["method"] == "euler_maruyama":
384378
raise ValueError("Stochastic methods are not supported for density computation.")

bayesflow/experimental/diffusion_model/schedules/edm_noise_schedule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool = False) -> Tenso
101101

102102
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
103103
"""Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
104-
# for F-prediction: w = (ops.exp(-log_snr_t) + sigma_data^2) / (ops.exp(-log_snr_t)*sigma_data^2)
104+
# for F-loss: w = (ops.exp(-log_snr_t) + sigma_data^2) / (ops.exp(-log_snr_t)*sigma_data^2)
105105
return 1 + ops.exp(-log_snr_t) / ops.square(self.sigma_data)
106106

107107
def get_config(self):

0 commit comments

Comments
 (0)