
Commit 376e88e

improved naming
1 parent 1d57c76 commit 376e88e

File tree

3 files changed, +25 -29 lines:
bayesflow/experimental/diffusion_model/diffusion_model.py
bayesflow/networks/consistency_models/consistency_model.py
bayesflow/networks/flow_matching/flow_matching.py


bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 11 additions & 10 deletions

@@ -15,6 +15,7 @@
     integrate,
     integrate_stochastic,
     logging,
+    tensor_utils,
 )
 from bayesflow.utils.serialization import serialize, deserialize, serializable

@@ -115,7 +116,7 @@ def __init__(
         if subnet == "mlp":
             subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
         self.subnet = find_network(subnet, **subnet_kwargs)
-        self._subnet_concatenated_input = subnet_kwargs.get("concatenated_input", True)
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")

@@ -149,7 +150,8 @@ def get_config(self):
             "prediction_type": self._prediction_type,
             "loss_type": self._loss_type,
             "integrate_kwargs": self.integrate_kwargs,
-            "subnet_concatenated_input": self._subnet_concatenated_input,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
+            # we do not need to store subnet_kwargs
         }
         return base_config | serialize(config)

@@ -198,7 +200,7 @@ def convert_prediction_to_x(
             return (z + sigma_t**2 * pred) / alpha_t
         raise ValueError(f"Unknown prediction type {self._prediction_type}.")

-    def subnet_input(
+    def _subnet_input(
         self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
     ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
         """

@@ -221,11 +223,8 @@ def subnet_input(
         Tensor
             The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
         """
-        if self._subnet_concatenated_input:
-            if conditions is None:
-                xtc = keras.ops.concatenate([xz, log_snr], axis=-1)
-            else:
-                xtc = keras.ops.concatenate([xz, log_snr, conditions], axis=-1)
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
             return self.subnet(xtc, training=training)
         else:
             return self.subnet(xz, log_snr, conditions, training=training)

@@ -271,7 +270,9 @@ def velocity(
         log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

-        subnet_out = self.subnet_input(xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training)
+        subnet_out = self._subnet_input(
+            xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
+        )
         pred = self.output_projector(subnet_out, training=training)

         x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

@@ -490,7 +491,7 @@ def compute_metrics(
         diffused_x = alpha_t * x + sigma_t * eps_t

         # calculate output of the network
-        subnet_out = self.subnet_input(
+        subnet_out = self._subnet_input(
             diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
         )
         pred = self.output_projector(subnet_out, training=training)
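
The collapsed branch relies on tensor_utils.concatenate_valid ignoring None entries, which is why the explicit "if conditions is None" check can be dropped. A minimal sketch of that assumed behavior (an assumption about the helper, not bayesflow's actual implementation):

import keras

def concatenate_valid_sketch(tensors, axis=-1):
    # Keep only the tensors that were actually provided, then concatenate.
    valid = [t for t in tensors if t is not None]
    return keras.ops.concatenate(valid, axis=axis)

# With conditions=None, only [xz, log_snr] are concatenated,
# matching the old explicit branch.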

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 7 additions & 10 deletions

@@ -4,7 +4,7 @@
 import numpy as np

 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean
+from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils
 from bayesflow.utils.serialization import deserialize, serializable, serialize

 from ..inference_network import InferenceNetwork

@@ -77,7 +77,7 @@ def __init__(
         subnet_kwargs = subnet_kwargs or {}
         if subnet == "mlp":
             subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
-        self._subnet_concatenated_input = subnet_kwargs.get("concatenated_input", True)
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(

@@ -120,7 +120,7 @@ def get_config(self):
             "eps": self.eps,
             "s0": self.s0,
             "s1": self.s1,
-            "subnet_concatenated_input": self._subnet_concatenated_input,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
             # we do not need to store subnet_kwargs
         }

@@ -258,7 +258,7 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
         x = self.consistency_function(x_n, t, conditions=conditions, training=training)
         return x

-    def subnet_input(
+    def _subnet_input(
         self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
     ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
         """

@@ -281,11 +281,8 @@ def subnet_input(
         Tensor
             The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
         """
-        if self._subnet_concatenated_input:
-            if conditions is None:
-                xtc = keras.ops.concatenate([x, t], axis=-1)
-            else:
-                xtc = keras.ops.concatenate([x, t, conditions], axis=-1)
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
             return self.subnet(xtc, training=training)
         else:
             return self.subnet(x, t, conditions, training=training)

@@ -305,7 +302,7 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
             Whether internal layers (e.g., dropout) should behave in train or inference mode.
         """

-        subnet_out = self.subnet_input(x, t, conditions, training=training)
+        subnet_out = self._subnet_input(x, t, conditions, training=training)
         f = self.output_projector(subnet_out)

         # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
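
Because the switch is now read from the model's own keyword arguments instead of subnet_kwargs, it would be passed at construction time. A hedged usage sketch; the custom subnet class below and the exact constructor arguments (including passing an instantiated subnet) are illustrative assumptions, not the documented bayesflow API:

import keras
from bayesflow.networks import ConsistencyModel  # import path assumed


class SeparateInputSubnet(keras.layers.Layer):
    # Hypothetical subnet that accepts x, t, and conditions as separate
    # arguments instead of a single concatenated tensor.
    def __init__(self, units=64, **kwargs):
        super().__init__(**kwargs)
        self.hidden = keras.layers.Dense(units, activation="gelu")

    def call(self, x, t, conditions=None, training=False):
        parts = [p for p in (x, t, conditions) if p is not None]
        return self.hidden(keras.ops.concatenate(parts, axis=-1))


# concatenate_subnet_input=False routes (x, t, conditions) to the subnet
# as separate arguments; the default True concatenates them first.
model = ConsistencyModel(subnet=SeparateInputSubnet(), concatenate_subnet_input=False)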

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 7 additions & 9 deletions

@@ -12,6 +12,7 @@
     layer_kwargs,
     optimal_transport,
     weighted_mean,
+    tensor_utils,
 )
 from bayesflow.utils.serialization import serialize, deserialize, serializable
 from ..inference_network import InferenceNetwork

@@ -107,7 +108,7 @@ def __init__(
         subnet_kwargs = subnet_kwargs or {}
         if subnet == "mlp":
             subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG | subnet_kwargs
-        self._subnet_concatenated_input = subnet_kwargs.get("concatenated_input", True)
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")

@@ -148,13 +149,13 @@ def get_config(self):
             "loss_fn": self.loss_fn,
             "integrate_kwargs": self.integrate_kwargs,
             "optimal_transport_kwargs": self.optimal_transport_kwargs,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
             # we do not need to store subnet_kwargs
-            "subnet_concatenated_input": self._subnet_concatenated_input,
         }

         return base_config | serialize(config)

-    def subnet_input(
+    def _subnet_input(
         self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
     ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
         """

@@ -177,11 +178,8 @@ def subnet_input(
         Tensor
             The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
         """
-        if self._subnet_concatenated_input:
-            if conditions is None:
-                xtc = keras.ops.concatenate([x, t], axis=-1)
-            else:
-                xtc = keras.ops.concatenate([x, t, conditions], axis=-1)
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
             return self.subnet(xtc, training=training)
         else:
             return self.subnet(x, t, conditions, training=training)

@@ -191,7 +189,7 @@ def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None,
         time = expand_right_as(time, xz)
         time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))

-        subnet_out = self.subnet_input(xz, time, conditions, training=training)
+        subnet_out = self._subnet_input(xz, time, conditions, training=training)
         return self.output_projector(subnet_out, training=training)

     def _velocity_trace(
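
Since get_config now stores the renamed key, a quick serialization check would look roughly like this; constructing FlowMatching with all defaults and the import path are assumptions, only the key name comes from this diff:

from bayesflow.networks import FlowMatching  # import path assumed

fm = FlowMatching()  # default subnet ("mlp") assumed to be sufficient
config = fm.get_config()

# The renamed key is written; the old one is gone.
assert "concatenate_subnet_input" in config
assert "subnet_concatenated_input" not in config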
