@@ -4,7 +4,7 @@
 import numpy as np
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean
+from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..inference_network import InferenceNetwork
@@ -67,6 +67,11 @@ def __init__(
             Final number of discretization steps
         subnet_kwargs: dict[str, any], optional
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (time),
+            and optionally 'conditions'. Default is True.
         **kwargs : dict, optional, default: {}
             Additional keyword arguments
         """
@@ -77,6 +82,7 @@ def __init__(
         subnet_kwargs = subnet_kwargs or {}
         if subnet == "mlp":
             subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
 
         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(
@@ -119,6 +125,7 @@ def get_config(self):
119125 "eps" : self .eps ,
120126 "s0" : self .s0 ,
121127 "s1" : self .s1 ,
128+ "concatenate_subnet_input" : self ._concatenate_subnet_input ,
122129 # we do not need to store subnet_kwargs
123130 }
124131
@@ -161,18 +168,23 @@ def build(self, xz_shape, conditions_shape=None):
 
         input_shape = list(xz_shape)
 
-        # time vector
-        input_shape[-1] += 1
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
 
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
-
-        input_shape = tuple(input_shape)
-
-        self.subnet.build(input_shape)
-
-        input_shape = self.subnet.compute_output_shape(input_shape)
-        self.output_projector.build(input_shape)
+            self.subnet.build(input_shape)
+            out_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (1,)  # same batch/sequence dims, 1 feature
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            out_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
+        self.output_projector.build(out_shape)
 
         # Choose coefficient according to [2] Section 3.3
         self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
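A quick shape check for the two branches above, with assumed example dimensions (batch of 64, 5 parameters, 10 condition features):

```python
xz_shape = (64, 5)           # (batch, num_parameters)
conditions_shape = (64, 10)  # (batch, num_condition_features)

# concatenated path: parameters + 1 time feature + condition features
input_shape = (64, 5 + 1 + 10)  # -> (64, 16), the width the subnet is built with

# separate-input path: the time tensor keeps the batch dims with one feature
time_shape = tuple(xz_shape[:-1]) + (1,)  # -> (64, 1)
```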
@@ -256,6 +268,35 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
         x = self.consistency_function(x_n, t, conditions=conditions, training=training)
         return x
 
+    def _apply_subnet(
+        self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+        """
+        Prepares the input for the subnet, either by concatenating the latent variable `x`,
+        the time `t`, and the optional conditions, or by passing them as separate arguments.
+
+        Parameters
+        ----------
+        x : Tensor
+            The parameter tensor, typically of shape (..., D), but can vary.
+        t : Tensor
+            The time tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g., observed data or summary statistics).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=x, t=t, conditions=conditions, training=training)
+
     def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
         """Compute consistency function.
 
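The concatenation branch leans on `tensor_utils.concatenate_valid`, which is why the old `if conditions is not None` split in `consistency_function` (next hunk) disappears. A rough sketch of its behavior, assuming it simply skips `None` entries; this is not the actual BayesFlow implementation:

```python
from keras import ops

def concatenate_valid_sketch(tensors, axis=-1):
    """Sketch of tensor_utils.concatenate_valid: concatenate the non-None entries."""
    return ops.concatenate([t for t in tensors if t is not None], axis=axis)
```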
@@ -271,12 +312,8 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
             Whether internal layers (e.g., dropout) should behave in train or inference mode.
         """
 
-        if conditions is not None:
-            xtc = ops.concatenate([x, t, conditions], axis=-1)
-        else:
-            xtc = ops.concatenate([x, t], axis=-1)
-
-        f = self.output_projector(self.subnet(xtc, training=training))
+        subnet_out = self._apply_subnet(x, t, conditions, training=training)
+        f = self.output_projector(subnet_out)
 
         # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim))
         # Thus, we can do a cross product with the time vector which is (batch_size, 1) for
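For reference, the skip/out parametrization of [2], Section 3.3 enforces the boundary condition f(x, ε) = x. A sketch of the coefficients as they are commonly written; the function and variable names here are illustrative, not the code that follows this hunk:

```python
from keras import ops

def skip_out_coefficients(t, sigma2, eps):
    """Sketch of the coefficients from [2], Sec. 3.3.
    c_skip -> 1 and c_out -> 0 as t -> eps, so f(x, eps) = x holds exactly."""
    c_skip = sigma2 / ((t - eps) ** 2 + sigma2)  # (batch, 1) vs (1, D) broadcasts to (batch, D)
    c_out = ops.sqrt(sigma2) * (t - eps) / ops.sqrt(sigma2 + t**2)
    return c_skip, c_out
```

The consistency function is then assembled as `f = c_skip * x + c_out * F(x, t)`, with `F` the projected subnet output.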
@@ -316,8 +353,8 @@ def compute_metrics(
 
         log_p = ops.log(p)
         times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
-        t1 = ops.take(discretized_time, times)[..., None]
-        t2 = ops.take(discretized_time, times + 1)[..., None]
+        t1 = expand_right_as(ops.take(discretized_time, times), x)
+        t2 = expand_right_as(ops.take(discretized_time, times + 1), x)
 
         # generate noise vector
         noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
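Switching from `[..., None]` to `expand_right_as` lets the sampled times broadcast against `x` even when `x` has more than one trailing dimension. A sketch of the utility's assumed behavior (right-padding with singleton axes); not the actual implementation:

```python
from keras import ops

def expand_right_as_sketch(t, reference):
    """Sketch of bayesflow.utils.expand_right_as: append size-1 axes
    to `t` until its rank matches that of `reference`."""
    while ops.ndim(t) < ops.ndim(reference):
        t = ops.expand_dims(t, axis=-1)
    return t
```

For 2D `x` this reproduces the old `[..., None]`; for, e.g., 3D `x` it appends two axes instead of one.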