Skip to content

Commit f4f0d11

Browse files
committed
improve dispatch
1 parent 0f96265 commit f4f0d11

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
noise_schedule : Literal['edm', 'cosine'], dict or type, optional
7575
The noise schedule used for the diffusion process. Can be "cosine" or "edm" or a custom noise schedule.
7676
You can also pass a dictionary with the configuration for the noise schedule, e.g.,
77-
{'type': cosine, 's_shift_cosine': 1.0}
77+
{'name': cosine, 's_shift_cosine': 1.0}
7878
Default is "edm".
7979
prediction_type: Literal['velocity', 'noise', 'F'], optional
8080
The type of prediction used in the diffusion model. Can be "velocity", "noise" or "F" (EDM).

bayesflow/experimental/diffusion_model/dispatch.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def _(name: str, *args, **kwargs):
2929

3030
@find_noise_schedule.register
3131
def _(config: dict, *args, **kwargs):
32-
name = config.get("type", "").lower()
33-
params = {k: v for k, v in config.items() if k != "type"}
32+
name = config.get("name", "").lower()
33+
params = {k: v for k, v in config.items() if k != "name"}
3434
match name:
3535
case "cosine":
3636
from .noise_schedules import CosineNoiseSchedule
@@ -46,9 +46,6 @@ def _(config: dict, *args, **kwargs):
4646

4747
@find_noise_schedule.register
4848
def _(cls: type, *args, **kwargs):
49-
# Lazily import NoiseSchedule class and compare
50-
from .noise_schedules import NoiseSchedule
51-
5249
if issubclass(cls, NoiseSchedule):
5350
return cls(*args, **kwargs)
5451
raise TypeError(f"Expected subclass of NoiseSchedule, got {cls}")

0 commit comments

Comments
 (0)