
Commit 79cf26f

Finish mixture, add docs
1 parent a814aa5 commit 79cf26f

File tree

8 files changed: +218 −85 lines


bayesflow/distributions/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from .distribution import Distribution
 from .diagonal_normal import DiagonalNormal
 from .diagonal_student_t import DiagonalStudentT
+from .mixture import Mixture
 
 from .find_distribution import find_distribution
 
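This exposes the new Mixture class at the package level alongside the existing distributions. A minimal sketch of the resulting import path:

# New public import path added by this commit
from bayesflow.distributions import DiagonalNormal, DiagonalStudentT, Mixture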

bayesflow/distributions/diagonal_normal.py

Lines changed: 30 additions & 19 deletions
@@ -1,16 +1,17 @@
-import keras
-from keras.saving import register_keras_serializable as serializable
-
 import math
+
 import numpy as np
 
+import keras
+
 from bayesflow.types import Shape, Tensor
 from bayesflow.utils.decorators import allow_batch_size
+from bayesflow.utils.serialization import serializable, serialize
 
 from .distribution import Distribution
 
 
-@serializable(package="bayesflow.distributions")
+@serializable
 class DiagonalNormal(Distribution):
     """Implements a backend-agnostic diagonal Gaussian distribution."""
 
@@ -65,10 +66,8 @@ def __init__(
     def build(self, input_shape: Shape) -> None:
         self.dim = int(input_shape[-1])
 
-        # convert to tensor and broadcast if necessary
         self.mean = keras.ops.broadcast_to(self.mean, (self.dim,))
         self.mean = keras.ops.cast(self.mean, "float32")
-
         self.std = keras.ops.broadcast_to(self.std, (self.dim,))
         self.std = keras.ops.cast(self.std, "float32")
 
@@ -77,24 +76,24 @@ def build(self, input_shape: Shape) -> None:
         )
 
         if self.use_learnable_parameters:
-            mean = self.mean
-            self.mean = self.add_weight(
-                shape=keras.ops.shape(mean),
-                initializer="zeros",
+            self._mean = self.add_weight(
+                shape=keras.ops.shape(self.mean),
+                # Initializing with const tensor https://github.com/keras-team/keras/pull/20457#discussion_r1832081248
+                initializer=keras.initializers.get(self.mean),
                 dtype="float32",
             )
-            self.mean.assign(mean)
-
-            std = self.std
-            self.std = self.add_weight(
-                shape=keras.ops.shape(std),
-                initializer="ones",
+            self._std = self.add_weight(
+                shape=keras.ops.shape(self.std),
+                # Initializing with const tensor https://github.com/keras-team/keras/pull/20457#discussion_r1832081248
+                initializer=keras.initializers.get(self.std),
                 dtype="float32",
             )
-            self.std.assign(std)
+        else:
+            self._mean = self.mean
+            self._std = self.std
 
     def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
-        result = -0.5 * keras.ops.sum((samples - self.mean) ** 2 / self.std**2, axis=-1)
+        result = -0.5 * keras.ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)
 
         if normalize:
             result += self.log_normalization_constant
@@ -103,4 +102,16 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
 
     @allow_batch_size
     def sample(self, batch_shape: Shape) -> Tensor:
-        return self.mean + self.std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator)
+        return self._mean + self._std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator)
+
+    def get_config(self):
+        base_config = super().get_config()
+
+        config = {
+            "mean": self.mean,
+            "std": self.std,
+            "use_learnable_parameters": self.use_learnable_parameters,
+            "seed_generator": self.seed_generator,
+        }
+
+        return base_config | serialize(config)
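The refactor routes log_prob and sample through the built parameters (_mean, _std), while get_config serializes the original constructor arguments. A quick sketch checking the normalized log-density against the Gaussian closed form -0.5 * dim * log(2π) at the origin; it assumes the constructor accepts scalar mean and std, as the get_config keys above suggest:

import math

import keras
from bayesflow.distributions import DiagonalNormal

dist = DiagonalNormal(mean=0.0, std=1.0)  # assumed constructor kwargs
dist.build((None, 2))  # broadcasts mean/std and sets dim = 2

x = keras.ops.zeros((1, 2))
print(float(dist.log_prob(x)[0]))        # expected ≈ -1.8379
print(-0.5 * 2 * math.log(2 * math.pi))  # closed-form reference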

bayesflow/distributions/diagonal_student_t.py

Lines changed: 26 additions & 15 deletions
@@ -1,17 +1,17 @@
 import keras
-from keras.saving import register_keras_serializable as serializable
 
 import math
 import numpy as np
 
 from bayesflow.types import Shape, Tensor
 from bayesflow.utils import expand_tile
 from bayesflow.utils.decorators import allow_batch_size
+from bayesflow.utils.serialization import serializable, serialize
 
 from .distribution import Distribution
 
 
-@serializable(package="bayesflow.distributions")
+@serializable
 class DiagonalStudentT(Distribution):
     """Implements a backend-agnostic diagonal Student-t distribution."""
 

@@ -86,24 +86,22 @@ def build(self, input_shape: Shape) -> None:
         )
 
         if self.use_learnable_parameters:
-            loc = self.loc
-            self.loc = self.add_weight(
-                shape=keras.ops.shape(loc),
-                initializer="zeros",
+            self._loc = self.add_weight(
+                shape=keras.ops.shape(self.loc),
+                initializer=keras.initializers.get(self.loc),
                 dtype="float32",
             )
-            self.loc.assign(loc)
-
-            scale = self.scale
-            self.scale = self.add_weight(
-                shape=keras.ops.shape(scale),
-                initializer="ones",
+            self._scale = self.add_weight(
+                shape=keras.ops.shape(self.scale),
+                initializer=keras.initializers.get(self.scale),
                 dtype="float32",
             )
-            self.scale.assign(scale)
+        else:
+            self._loc = self.loc
+            self._scale = self.scale
 
     def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
-        mahalanobis_term = keras.ops.sum((samples - self.loc) ** 2 / self.scale**2, axis=-1)
+        mahalanobis_term = keras.ops.sum((samples - self._loc) ** 2 / self._scale**2, axis=-1)
         result = -0.5 * (self.df + self.dim) * keras.ops.log1p(mahalanobis_term / self.df)
 
         if normalize:
@@ -124,4 +122,17 @@ def sample(self, batch_shape: Shape) -> Tensor:
 
         normal_samples = keras.random.normal(batch_shape + (self.dim,), seed=self.seed_generator)
 
-        return self.loc + self.scale * normal_samples * keras.ops.sqrt(self.df / chi2_samples)
+        return self._loc + self._scale * normal_samples * keras.ops.sqrt(self.df / chi2_samples)
+
+    def get_config(self):
+        base_config = super().get_config()
+
+        config = {
+            "df": self.df,
+            "loc": self.loc,
+            "scale": self.scale,
+            "use_learnable_parameters": self.use_learnable_parameters,
+            "seed_generator": self.seed_generator,
+        }
+
+        return base_config | serialize(config)
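The sample method relies on the classic scale-mixture representation: if Z ~ N(0, 1) and V ~ chi-squared(df), then Z * sqrt(df / V) follows a Student-t distribution with df degrees of freedom. A self-contained NumPy check against the known variance df / (df - 2) for df > 2; the variable names here are illustrative only:

import numpy as np

rng = np.random.default_rng(0)
df, n = 5.0, 1_000_000

z = rng.standard_normal(n)       # Z ~ N(0, 1)
v = rng.chisquare(df, size=n)    # V ~ chi-squared(df)
t_samples = z * np.sqrt(df / v)  # Student-t with df degrees of freedom

print(t_samples.var())  # should be close to df / (df - 2) = 5 / 3
print(df / (df - 2.0))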

bayesflow/distributions/distribution.py

Lines changed: 6 additions & 0 deletions
@@ -2,8 +2,10 @@
 
 from bayesflow.types import Shape, Tensor
 from bayesflow.utils import layer_kwargs
+from bayesflow.utils.serialization import serializable, deserialize
 
 
+@serializable
 class Distribution(keras.Layer):
     def __init__(self, **kwargs):
         super().__init__(**layer_kwargs(kwargs))
@@ -19,3 +21,7 @@ def sample(self, batch_shape: Shape) -> Tensor:
 
     def compute_output_shape(self, input_shape: Shape) -> Shape:
         return keras.ops.shape(self.sample(input_shape[0:1]))
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
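Together with the get_config methods added above, this shared from_config lets the distributions round-trip through the standard Keras serialization machinery. A hedged sketch, again assuming DiagonalNormal's constructor matches its get_config keys:

import keras
from bayesflow.distributions import DiagonalNormal

dist = DiagonalNormal(mean=0.0, std=1.0)

# Serialize to a plain config dict, then rebuild via from_config.
config = keras.saving.serialize_keras_object(dist)
restored = keras.saving.deserialize_keras_object(config)
assert isinstance(restored, DiagonalNormal)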

bayesflow/distributions/mixture.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+from collections.abc import Sequence
+
+import numpy as np
+
+import keras
+from keras import ops
+
+from bayesflow.types import Shape, Tensor
+from bayesflow.utils.decorators import allow_batch_size
+from bayesflow.utils.serialization import serializable, serialize
+from bayesflow.distributions import Distribution
+
+
+@serializable
+class Mixture(Distribution):
+    """Utility class for a backend-agnostic mixture distribution."""
+
+    def __init__(
+        self,
+        distributions: Sequence[Distribution],
+        mixture_logits: Sequence[float] = None,
+        trainable_mixture: bool = False,
+        **kwargs,
+    ):
+        """
+        Initializes a mixture of distributions as a latent distribution.
+
+        Parameters
+        ----------
+        distributions : Sequence[Distribution]
+            A sequence of `Distribution` instances to form the mixture components.
+        mixture_logits : Sequence[float], optional
+            Initial unnormalized log-weights for each component. If `None`, all
+            components are assigned equal weight. Default is `None`.
+        trainable_mixture : bool, optional
+            Whether the mixture weights (`mixture_logits`) should be trainable.
+            Default is `False`.
+        **kwargs
+            Additional keyword arguments passed to the base `Distribution` class.
+
+        Attributes
+        ----------
+        distributions : Sequence[Distribution]
+            The list of component distributions.
+        mixture_logits : Tensor
+            Trainable or fixed logits representing the mixture weights.
+        dim : int or None
+            Dimensionality of the output samples; set when the layer is built.
+        """
+
+        super().__init__(**kwargs)
+
+        self.dim = None
+        self.distributions = distributions
+
+        if mixture_logits is None:
+            mixture_logits = keras.ops.ones(shape=len(distributions))
+
+        self.mixture_logits = mixture_logits
+        self._mixture_logits = self.add_weight(
+            shape=(len(distributions),),
+            initializer=keras.initializers.Constant(value=mixture_logits),
+            dtype="float32",
+            trainable=trainable_mixture,
+        )
+
+        self.trainable_mixture = trainable_mixture
+
+    @allow_batch_size
+    def sample(self, batch_shape: Shape) -> Tensor:
+        """
+        Draws samples from the mixture distribution by sampling a categorical index
+        for each entry in `batch_shape` according to the softmax of `mixture_logits`,
+        then drawing from the corresponding component distribution.
+
+        Parameters
+        ----------
+        batch_shape : Shape
+            The desired sample batch shape (tuple of ints), not including the
+            event dimension.
+
+        Returns
+        -------
+        samples : Tensor
+            A tensor of shape `batch_shape + (dim,)` containing samples drawn
+            from the mixture.
+        """
+        # Will use numpy until keras adds support for N-D categorical sampling
+        pvals = keras.ops.convert_to_numpy(keras.ops.softmax(self._mixture_logits))
+        cat_samples = np.random.multinomial(n=1, pvals=pvals, size=batch_shape)
+        cat_samples = cat_samples.argmax(axis=-1)
+
+        # Prepare array to fill and dtype to infer
+        samples = np.zeros(batch_shape + (self.dim,))
+        dtype = None
+
+        # Fill in array with vectorized sampling per component
+        for i in range(len(self.distributions)):
+            dist_mask = cat_samples == i
+            dist_indices = np.where(dist_mask)
+            num_dist_samples = np.sum(dist_mask)
+            dist_samples = keras.ops.convert_to_numpy(self.distributions[i].sample(num_dist_samples))
+
+            samples[dist_indices] = dist_samples
+
+            dtype = dtype or keras.ops.dtype(dist_samples)
+
+        # Convert to keras for compatibility
+        samples = keras.ops.convert_to_tensor(samples, dtype=dtype)
+
+        return samples
+
+    def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
+        """
+        Computes the log probability of the given samples under the mixture.
+
+        For each input sample, computes the log-sum-exp of the component
+        log-probabilities plus the mixture log-weights.
+
+        Parameters
+        ----------
+        samples : Tensor
+            A tensor of samples with shape `batch_shape + (dim,)`.
+        normalize : bool, optional
+            If `True`, returns normalized log-probabilities (i.e., includes the
+            log normalization constant). Default is `True`.
+
+        Returns
+        -------
+        Tensor
+            A tensor of shape `batch_shape` containing the log probability of
+            each sample under the mixture distribution.
+        """
+
+        log_prob = [distribution.log_prob(samples, normalize=normalize) for distribution in self.distributions]
+        log_prob = ops.stack(log_prob, axis=-1)
+        log_prob = ops.logsumexp(log_prob + ops.log_softmax(self._mixture_logits), axis=-1)
+        return log_prob
+
+    def build(self, input_shape: Shape) -> None:
+        for distribution in self.distributions:
+            distribution.build(input_shape)
+
+        self.dim = input_shape[-1]
+
+    def get_config(self):
+        base_config = super().get_config()
+
+        config = {
+            "distributions": self.distributions,
+            "mixture_logits": self.mixture_logits,
+            "trainable_mixture": self.trainable_mixture,
+        }
+
+        return base_config | serialize(config)
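A usage sketch for the new class; the constructor arguments come from the diff above, while the component parameters are illustrative. The mixture density is p(x) = Σ_k softmax(logits)_k · p_k(x), which log_prob evaluates stably in log space as a logsumexp over log p_k(x) + log_softmax(logits)_k:

import keras
from bayesflow.distributions import DiagonalNormal, Mixture

# Two well-separated Gaussian components (illustrative parameters)
mixture = Mixture(
    distributions=[
        DiagonalNormal(mean=-2.0, std=0.5),
        DiagonalNormal(mean=2.0, std=0.5),
    ],
    trainable_mixture=True,  # make the mixture logits learnable
)
mixture.build((None, 2))  # builds each component and sets dim = 2

samples = mixture.sample((64,))          # shape (64, 2)
log_density = mixture.log_prob(samples)  # shape (64,)
print(keras.ops.shape(samples), keras.ops.shape(log_density))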

bayesflow/distributions/mixture_distribution.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

tests/test_distributions/test_diagonal_normal.py

Whitespace-only changes.

tests/test_distributions/test_diagonal_student_t.py

Whitespace-only changes.
