Skip to content

Commit 79be9ab

Browse files
committed
fix stochastic sampler
1 parent 5ca609f commit 79be9ab

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,8 @@ def __init__(
374374
subnet: str | type = "mlp",
375375
integrate_kwargs: dict[str, any] = None,
376376
subnet_kwargs: dict[str, any] = None,
377-
noise_schedule: str | NoiseSchedule = "cosine",
378-
prediction_type: str = "velocity",
377+
noise_schedule: str | NoiseSchedule = "edm",
378+
prediction_type: str = "F",
379379
**kwargs,
380380
):
381381
"""
@@ -398,10 +398,10 @@ def __init__(
398398
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
399399
noise_schedule : str or NoiseSchedule, optional
400400
The noise schedule used for the diffusion process. Can be "linear", "cosine", or "edm".
401-
Default is "cosine".
401+
Default is "edm".
402402
prediction_type: str, optional
403403
The type of prediction used in the diffusion model. Can be "velocity", "noise" or "F" (EDM).
404-
Default is "velocity".
404+
Default is "F".
405405
**kwargs
406406
Additional keyword arguments passed to the subnet and other components.
407407
"""
@@ -425,10 +425,6 @@ def __init__(
425425
if prediction_type not in ["noise", "velocity", "F"]: # F is EDM
426426
raise ValueError(f"Unknown prediction type: {prediction_type}")
427427
self._prediction_type = prediction_type
428-
if noise_schedule.name == "edm_noise_schedule" and prediction_type != "F":
429-
warnings.warn(
430-
"EDM noise schedule is build for F-prediction. Consider using F-prediction instead.",
431-
)
432428
self._loss_type = kwargs.get("loss_type", "noise")
433429
if self._loss_type not in ["noise", "velocity", "F"]:
434430
raise ValueError(f"Unknown loss type: {self._loss_type}")

bayesflow/utils/integrate.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -391,15 +391,18 @@ def integrate_stochastic(
391391

392392
# Prepare step function with partial application
393393
step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, seed=seed, **kwargs)
394-
step_size = (stop_time - start_time) / steps
395394

395+
step_size = (stop_time - start_time) / steps
396396
time = start_time
397+
current_state = state.copy()
398+
399+
# keras.ops.fori_loop does not support keras seed generator in jax
400+
for i in range(steps):
401+
# Execute the step with the specific seed for this step
402+
current_state, time = step_fn(
403+
state=current_state,
404+
time=time,
405+
step_size=step_size,
406+
)
397407

398-
def body(_loop_var, _loop_state):
399-
_state, _time = _loop_state
400-
_state, _time = step_fn(state=_state, time=_time, step_size=step_size)
401-
402-
return _state, _time
403-
404-
state, time = keras.ops.fori_loop(0, steps, body, (state, time))
405-
return state
408+
return current_state

0 commit comments

Comments
 (0)