Skip to content

Commit 3ea64b7

Browse files
committed
fix subnet input dimensions
1 parent 376e88e commit 3ea64b7

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from bayesflow.types import Tensor
7-
from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils
7+
from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
88
from bayesflow.utils.serialization import deserialize, serializable, serialize
99

1010
from ..inference_network import InferenceNetwork
@@ -343,8 +343,8 @@ def compute_metrics(
343343

344344
log_p = ops.log(p)
345345
times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
346-
t1 = ops.take(discretized_time, times)[..., None]
347-
t2 = ops.take(discretized_time, times + 1)[..., None]
346+
t1 = expand_right_as(ops.take(discretized_time, times), x)
347+
t2 = expand_right_as(ops.take(discretized_time, times + 1), x)
348348

349349
# generate noise vector
350350
noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,17 @@ def _subnet_input(
179179
The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
180180
"""
181181
if self._concatenate_subnet_input:
182+
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
182183
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
183184
return self.subnet(xtc, training=training)
184185
else:
186+
if training is False:
187+
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
185188
return self.subnet(x, t, conditions, training=training)
186189

187190
def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
188191
time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
189192
time = expand_right_as(time, xz)
190-
time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))
191193

192194
subnet_out = self._subnet_input(xz, time, conditions, training=training)
193195
return self.output_projector(subnet_out, training=training)

0 commit comments

Comments
 (0)