Skip to content

Commit 2fd5a90

Browse files
authored
Merge pull request #440 from bayesflow-org/feat-stochastic-sampler
Feat stochastic sampler
2 parents 59a349b + 9ed482d commit 2fd5a90

File tree

3 files changed

+203
-13
lines changed

3 files changed

+203
-13
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
layer_kwargs,
1919
weighted_mean,
2020
integrate,
21+
integrate_stochastic,
2122
)
2223

2324

@@ -373,7 +374,7 @@ class DiffusionModel(InferenceNetwork):
373374
}
374375

375376
INTEGRATE_DEFAULT_CONFIG = {
376-
"method": "euler",
377+
"method": "euler", # or euler_maruyama
377378
"steps": 100,
378379
}
379380

@@ -529,6 +530,7 @@ def velocity(
529530
time: float | Tensor,
530531
conditions: Tensor = None,
531532
training: bool = False,
533+
stochastic_solver: bool = False,
532534
clip_x: bool = False,
533535
) -> Tensor:
534536
# calculate the current noise level and transform into correct shape
@@ -548,13 +550,30 @@ def velocity(
548550
# convert x to score
549551
score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
550552

551-
# compute velocity for the ODE depending on the noise schedule
553+
# compute velocity f, g of the SDE or ODE
552554
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
553-
out = f - 0.5 * g_squared * score
554555

555-
# todo: for the SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
556+
if stochastic_solver:
557+
# for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW
558+
out = f - g_squared * score
559+
else:
560+
# for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt
561+
out = f - 0.5 * g_squared * score
562+
556563
return out
557564

565+
def compute_diffusion_term(
    self,
    xz: Tensor,
    time: float | Tensor,
    training: bool = False,
) -> Tensor:
    """
    Evaluates the diffusion coefficient of the SDE at the given time.

    Args:
        xz: Current state. Only its shape is used here.
        time: Time at which the noise schedule is queried.
        training: Forwarded to the noise schedule's ``get_log_snr``.

    Returns:
        The element-wise square root of what ``noise_schedule.get_drift_diffusion``
        returns for the log-SNR at ``time``.
    """
    # Log signal-to-noise ratio at `time`, expanded via `expand_right_as` to line up with xz
    raw_log_snr = self.noise_schedule.get_log_snr(t=time, training=training)
    log_snr_t = expand_right_as(raw_log_snr, xz)

    # Broadcast to the leading dimensions of xz, keeping a trailing singleton axis
    target_shape = keras.ops.shape(xz)[:-1] + (1,)
    log_snr_t = keras.ops.broadcast_to(log_snr_t, target_shape)

    # NOTE(review): presumably only g^2 is returned when `x` is omitted
    # (with `x` given, the call returns the pair (f, g^2)) - confirm against the schedule
    squared_diffusion = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t)
    return ops.sqrt(squared_diffusion)
576+
558577
def _velocity_trace(
559578
self,
560579
xz: Tensor,
@@ -586,6 +605,9 @@ def _forward(
586605
| self.integrate_kwargs
587606
| kwargs
588607
)
608+
if integrate_kwargs["method"] == "euler_maruyama":
609+
raise ValueError("Stoachastic methods are not supported for forward integration.")
610+
589611
if density:
590612

591613
def deltas(time, xz):
@@ -636,6 +658,8 @@ def _inverse(
636658
| kwargs
637659
)
638660
if density:
661+
if integrate_kwargs["method"] == "euler_maruyama":
662+
raise ValueError("Stoachastic methods are not supported for density computation.")
639663

640664
def deltas(time, xz):
641665
v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
@@ -656,11 +680,23 @@ def deltas(time, xz):
656680
return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}
657681

658682
state = {"xz": z}
659-
state = integrate(
660-
deltas,
661-
state,
662-
**integrate_kwargs,
663-
)
683+
if integrate_kwargs["method"] == "euler_maruyama":
684+
685+
def diffusion(time, xz):
686+
return {"xz": self.compute_diffusion_term(xz, time=time, training=training)}
687+
688+
state = integrate_stochastic(
689+
deltas,
690+
diffusion,
691+
state,
692+
**integrate_kwargs,
693+
)
694+
else:
695+
state = integrate(
696+
deltas,
697+
state,
698+
**integrate_kwargs,
699+
)
664700

665701
x = state["xz"]
666702
return x

bayesflow/utils/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
repo_url,
3030
)
3131
from .hparam_utils import find_batch_size, find_memory_budget
32-
from .integrate import (
33-
integrate,
34-
)
32+
from .integrate import integrate, integrate_stochastic
3533
from .io import (
3634
pickle_load,
3735
format_bytes,

bayesflow/utils/integrate.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import keras
55

66
import numpy as np
7-
from typing import Literal
7+
from typing import Literal, Union, List
88

99
from bayesflow.types import Tensor
1010
from bayesflow.utils import filter_kwargs
@@ -293,3 +293,159 @@ def integrate(
293293
return integrate_scheduled(fn, state, steps, method, **kwargs)
294294
else:
295295
raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")
296+
297+
298+
def euler_maruyama_step(
    drift_fn: Callable,
    diffusion_fn: Callable,
    state: dict[str, ArrayLike],
    time: ArrayLike,
    step_size: ArrayLike,
    noise: dict[str, ArrayLike] | None = None,
    tolerance: ArrayLike = 1e-6,
    min_step_size: ArrayLike = -float("inf"),
    max_step_size: ArrayLike = float("inf"),
    use_adaptive_step_size: bool = False,
) -> tuple[dict[str, ArrayLike], ArrayLike, ArrayLike]:
    """
    Performs a single Euler-Maruyama step for stochastic differential equations.

    The update is ``x_new = x + step_size * drift + diffusion * noise`` for every
    key returned by ``drift_fn``. Keys without a diffusion term receive the
    deterministic update ``x_new = x + step_size * drift``.

    Args:
        drift_fn: Function that computes the drift term. Called as
            ``drift_fn(time, **state_entries)`` and must return a dictionary.
        diffusion_fn: Function that computes the diffusion term. Called as
            ``diffusion_fn(time, **state_entries)`` and must return a dictionary.
        state: Dictionary containing the current state.
        time: Current time.
        step_size: Size of the integration step.
        noise: Dictionary of noise terms, already scaled by ``sqrt(|step_size|)``.
            Must contain an entry for every key of the diffusion dictionary;
            additional entries are ignored. If None, noise is drawn with the
            shape of each diffusion term.
        tolerance: Error tolerance for adaptive step size.
        min_step_size: Minimum allowed step size.
        max_step_size: Maximum allowed step size.
        use_adaptive_step_size: Whether to compute a suggested step size for
            the next step. The state returned by this call always uses the
            incoming ``step_size``.

    Returns:
        Tuple of (new_state, new_time, new_step_size).

    Raises:
        ValueError: If ``noise`` lacks an entry for a key of the diffusion dictionary.
    """
    drift = drift_fn(time, **filter_kwargs(state, drift_fn))
    diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))

    # Generate noise if not provided
    if noise is None:
        noise_scale = keras.ops.sqrt(keras.ops.abs(step_size))
        noise = {key: keras.random.normal(keras.ops.shape(term)) * noise_scale for key, term in diffusion.items()}

    # Every diffusion term needs matching noise. Extra noise entries (e.g. for
    # state variables that only have a drift) are tolerated and simply unused.
    if set(diffusion) - set(noise):
        raise ValueError("Keys of diffusion terms and noise do not match.")

    # Apply updates using Euler-Maruyama formula: dx = f(x)dt + g(x)dW
    new_state = state.copy()
    for key, drift_term in drift.items():
        updated = state[key] + (step_size * drift_term)
        if key in diffusion:
            updated = updated + (diffusion[key] * noise[key])
        new_state[key] = updated

    new_step_size = step_size
    if use_adaptive_step_size:
        # Re-evaluate the drift at the proposed (full-step) state and use the
        # change in drift as the error estimate.
        proposed_drift = drift_fn(time + step_size, **filter_kwargs(new_state, drift_fn))
        drift_change = keras.ops.stack(
            [keras.ops.norm(proposed_drift[key] - drift[key], ord=2, axis=-1) for key in drift]
        )
        new_step_size = step_size * tolerance / (drift_change + 1e-9)

        # Apply constraints to step size
        new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)

        # Consolidate to the single candidate with the smallest magnitude
        new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size)))

    new_time = time + step_size

    return new_state, new_time, new_step_size
383+
384+
385+
def integrate_stochastic(
386+
drift_fn: Callable,
387+
diffusion_fn: Callable,
388+
state: dict[str, ArrayLike],
389+
start_time: ArrayLike,
390+
stop_time: ArrayLike,
391+
steps: int,
392+
method: str = "euler_maruyama",
393+
seed: int = None,
394+
**kwargs,
395+
) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, List[ArrayLike]]]]:
396+
"""
397+
Integrates a stochastic differential equation from start_time to stop_time.
398+
399+
Args:
400+
drift_fn: Function that computes the drift term.
401+
diffusion_fn: Function that computes the diffusion term.
402+
state: Dictionary containing the initial state.
403+
start_time: Starting time for integration.
404+
stop_time: Ending time for integration.
405+
steps: Number of integration steps.
406+
method: Integration method to use ('euler_maruyama').
407+
seed: Random seed for noise generation.
408+
**kwargs: Additional arguments to pass to the step function.
409+
410+
Returns:
411+
If return_noise is False, returns the final state dictionary.
412+
If return_noise is True, returns a tuple of (final_state, noise_history).
413+
"""
414+
if steps <= 0:
415+
raise ValueError("Number of steps must be positive.")
416+
417+
# Set random seed if provided
418+
if seed is not None:
419+
keras.random.set_seed(seed)
420+
421+
# Select step function based on method
422+
match method:
423+
case "euler_maruyama":
424+
step_fn = euler_maruyama_step
425+
case str() as name:
426+
raise ValueError(f"Unknown integration method name: {name!r}")
427+
case other:
428+
raise TypeError(f"Invalid integration method: {other!r}")
429+
430+
# Prepare step function with partial application
431+
step_fn = partial(step_fn, drift_fn, diffusion_fn, **kwargs)
432+
step_size = (stop_time - start_time) / steps
433+
434+
time = start_time
435+
436+
def body(_loop_var, _loop_state):
437+
_state, _time = _loop_state
438+
439+
# Generate noise for this step
440+
_noise = {}
441+
for key in _state.keys():
442+
shape = keras.ops.shape(_state[key])
443+
_noise[key] = keras.random.normal(shape) * keras.ops.sqrt(keras.ops.abs(step_size))
444+
445+
# Perform integration step
446+
_state, _time, _ = step_fn(_state, _time, step_size, noise=_noise)
447+
448+
return _state, _time
449+
450+
state, time = keras.ops.fori_loop(0, steps, body, (state, time))
451+
return state

0 commit comments

Comments
 (0)