Skip to content

Commit 4d8596b

Browse files
committed
Update student and mixture with multiple dims too
1 parent bc2bda8 commit 4d8596b

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

bayesflow/distributions/diagonal_student_t.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,19 @@ def __init__(
6363

6464
self.seed_generator = seed_generator or keras.random.SeedGenerator()
6565

66-
self.dim = None
66+
self.dims = None
6767
self._loc = None
6868
self._scale = None
6969

7070
def build(self, input_shape: Shape) -> None:
7171
if self.built:
7272
return
7373

74-
self.dim = int(input_shape[-1])
74+
self.dims = tuple(input_shape[1:])
7575

7676
# convert to tensor and broadcast if necessary
77-
self.loc = ops.cast(ops.broadcast_to(self.loc, (self.dim,)), "float32")
78-
self.scale = ops.cast(ops.broadcast_to(self.scale, (self.dim,)), "float32")
77+
self.loc = ops.cast(ops.broadcast_to(self.loc, self.dims), "float32")
78+
self.scale = ops.cast(ops.broadcast_to(self.scale, self.dims), "float32")
7979

8080
if self.trainable_parameters:
8181
self._loc = self.add_weight(
@@ -96,14 +96,14 @@ def build(self, input_shape: Shape) -> None:
9696

9797
def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
9898
mahalanobis_term = ops.sum((samples - self._loc) ** 2 / self._scale**2, axis=-1)
99-
result = -0.5 * (self.df + self.dim) * ops.log1p(mahalanobis_term / self.df)
99+
result = -0.5 * (self.df + sum(self.dims)) * ops.log1p(mahalanobis_term / self.df)
100100

101101
if normalize:
102102
log_normalization_constant = (
103-
-0.5 * self.dim * math.log(self.df)
104-
- 0.5 * self.dim * math.log(math.pi)
103+
-0.5 * sum(self.dims) * math.log(self.df)
104+
- 0.5 * sum(self.dims) * math.log(math.pi)
105105
- math.lgamma(0.5 * self.df)
106-
+ math.lgamma(0.5 * (self.df + self.dim))
106+
+ math.lgamma(0.5 * (self.df + sum(self.dims)))
107107
- ops.sum(keras.ops.log(self._scale))
108108
)
109109
result += log_normalization_constant
@@ -119,9 +119,10 @@ def sample(self, batch_shape: Shape) -> Tensor:
119119

120120
# The chi-square samples need to be repeated across self.dims
121121
# since for each element of batch_shape only one sample is created.
122-
chi2_samples = expand_tile(chi2_samples, n=self.dim, axis=-1)
122+
chi2_samples = expand_tile(chi2_samples, n=sum(self.dims), axis=-1)
123+
chi2_samples = keras.ops.reshape(chi2_samples, batch_shape + self.dims)
123124

124-
normal_samples = keras.random.normal(batch_shape + (self.dim,), seed=self.seed_generator)
125+
normal_samples = keras.random.normal(batch_shape + self.dims, seed=self.seed_generator)
125126

126127
return self._loc + self._scale * normal_samples * ops.sqrt(self.df / chi2_samples)
127128

bayesflow/distributions/mixture.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959

6060
self.trainable_mixture = trainable_mixture
6161

62-
self.dim = None
62+
self.dims = None
6363
self._mixture_logits = None
6464

6565
@allow_batch_size
@@ -78,7 +78,7 @@ def sample(self, batch_shape: Shape) -> Tensor:
7878
Returns
7979
-------
8080
samples: Tensor
81-
A tensor of shape `batch_shape + (dim,)` containing samples drawn
81+
A tensor of shape `batch_shape + dims` containing samples drawn
8282
from the mixture.
8383
"""
8484
# Will use numpy until keras adds support for N-D categorical sampling
@@ -87,7 +87,7 @@ def sample(self, batch_shape: Shape) -> Tensor:
8787
cat_samples = cat_samples.argmax(axis=-1)
8888

8989
# Prepare array to fill and dtype to infer
90-
samples = np.zeros(batch_shape + (self.dim,))
90+
samples = np.zeros(batch_shape + self.dims)
9191
dtype = None
9292

9393
# Fill in array with vectorized sampling per component
@@ -137,7 +137,7 @@ def build(self, input_shape: Shape) -> None:
137137
if self.built:
138138
return
139139

140-
self.dim = input_shape[-1]
140+
self.dims = tuple(input_shape[1:])
141141

142142
for distribution in self.distributions:
143143
distribution.build(input_shape)

0 commit comments

Comments
 (0)