
from bayesflow.networks import MLP
from bayesflow.types import Tensor
-from bayesflow.utils import (
-    logging,
-    jvp,
-    concatenate_valid,
-    find_network,
-    expand_right_as,
-    expand_right_to,
-    layer_kwargs,
-)
+from bayesflow.utils import logging, jvp, find_network, expand_right_as, expand_right_to, layer_kwargs, tensor_utils
from bayesflow.utils.serialization import deserialize, serializable, serialize

from bayesflow.networks import InferenceNetwork
@@ -83,6 +75,11 @@ def __init__(
        includes depth, hidden sizes, and non-linearity choices.
    embedding_kwargs : dict[str, any], optional, default=None
        Keyword arguments for the time embedding layer(s) used in the model.
+    concatenate_subnet_input : bool, optional, default=True
+        Flag for advanced users to control whether all inputs to the subnet are concatenated
+        into a single vector or passed as separate arguments. If set to False, the subnet
+        must accept three separate inputs: 'x' (noisy parameters), 't' (embedded time),
+        and optional 'conditions'; see the sketch after this hunk.
    **kwargs
        Additional keyword arguments passed to the parent ``InferenceNetwork`` initializer
        (e.g., ``name``, ``dtype``, or ``trainable``).
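
As referenced in the docstring above, a subnet compatible with `concatenate_subnet_input=False` could look roughly like this. A minimal sketch: the keyword signature (`x`, `t`, `conditions`) follows the docstring, and the `build`/`compute_output_shape` keyword contract mirrors how `build()` calls the subnet further down in this diff; the class name, layer sizes, and additive fusion are assumptions for illustration.

import keras

class SeparateInputSubnet(keras.Layer):
    """Toy subnet taking x, t, and conditions as separate inputs."""

    def __init__(self, width: int = 256, **kwargs):
        super().__init__(**kwargs)
        self.dense_x = keras.layers.Dense(width)
        self.dense_t = keras.layers.Dense(width)
        self.dense_c = keras.layers.Dense(width)
        self.out = keras.layers.Dense(width, activation="silu")

    def build(self, x_shape, t_shape, conditions_shape=None):
        self.dense_x.build(x_shape)
        self.dense_t.build(t_shape)
        if conditions_shape is not None:
            self.dense_c.build(conditions_shape)
        self.out.build(tuple(x_shape[:-1]) + (self.dense_x.units,))

    def call(self, x, t, conditions=None, training=False):
        h = self.dense_x(x) + self.dense_t(t)  # fuse inputs additively instead of by concatenation
        if conditions is not None:
            h = h + self.dense_c(conditions)
        return self.out(h)

    def compute_output_shape(self, x_shape, t_shape=None, conditions_shape=None):
        return tuple(x_shape[:-1]) + (self.out.units,)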
@@ -97,6 +94,7 @@ def __init__(
        self.subnet_projector = keras.layers.Dense(
            units=None, bias_initializer="zeros", kernel_initializer="zeros", name="subnet_projector"
        )
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

        weight_mlp_kwargs = weight_mlp_kwargs or {}
        weight_mlp_kwargs = StableConsistencyModel.WEIGHT_MLP_DEFAULT_CONFIG | weight_mlp_kwargs
@@ -107,6 +105,7 @@ def __init__(
        )

        embedding_kwargs = embedding_kwargs or {}
+        self.embedding_kwargs = embedding_kwargs
        self.time_emb = FourierEmbedding(**embedding_kwargs)
        self.time_emb_dim = self.time_emb.embed_dim

@@ -124,6 +123,8 @@ def get_config(self):
        config = {
            "subnet": self.subnet,
            "sigma": self.sigma,
+            "embedding_kwargs": self.embedding_kwargs,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
        }

        return base_config | serialize(config)
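
With these two entries in the config, a plain Keras round trip should preserve the new settings. A hedged sketch; the constructor arguments are illustrative, not the actual required signature:

net = StableConsistencyModel(subnet="mlp", concatenate_subnet_input=True)
restored = StableConsistencyModel.from_config(net.get_config())
assert restored._concatenate_subnet_input == net._concatenate_subnet_input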
@@ -151,17 +152,22 @@ def build(self, xz_shape, conditions_shape=None):
        # construct input shape for subnet and subnet projector
        input_shape = list(xz_shape)

-        # time vector
-        input_shape[-1] += self.time_emb_dim + 1
-
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
-
-        input_shape = tuple(input_shape)
-
-        self.subnet.build(input_shape)
-
-        input_shape = self.subnet.compute_output_shape(input_shape)
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += self.time_emb_dim + 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
+
+            self.subnet.build(input_shape)
+            input_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (self.time_emb_dim + 1,)  # same batch/sequence dims, time_emb_dim + 1 time features
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            input_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
        self.subnet_projector.build(input_shape)

        # input shape for time embedding
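
For orientation, a worked instance of the concatenated branch above, with all numbers assumed: given xz_shape = (32, 4), time_emb_dim = 8, and conditions_shape = (32, 6), the subnet sees 4 + (8 + 1) + 6 = 19 input features.

xz_shape, time_emb_dim, conditions_shape = (32, 4), 8, (32, 6)
input_shape = list(xz_shape)
input_shape[-1] += time_emb_dim + 1      # embedded time plus the raw time value (default FourierEmbedding setup assumed)
input_shape[-1] += conditions_shape[-1]  # condition features
assert tuple(input_shape) == (32, 19)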
@@ -173,6 +179,35 @@ def build(self, xz_shape, conditions_shape=None):
        input_shape = self.weight_fn.compute_output_shape(input_shape)
        self.weight_fn_projector.build(input_shape)

+    def _apply_subnet(
+        self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor:
+        """
+        Prepares the input for the subnet, either by concatenating the latent variable `x`,
+        the embedded time `t`, and the optional conditions into a single tensor, or by
+        passing them to the subnet as separate arguments.
+
+        Parameters
+        ----------
+        x : Tensor
+            The parameter tensor, typically of shape (..., D), but can vary.
+        t : Tensor
+            The embedded time tensor, typically of shape (..., time_emb_dim + 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g., summary statistics of the observables).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor of the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=x, t=t, conditions=conditions, training=training)
+
    def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
        # Consistency Models only learn the direction from noise distribution
        # to target distribution, so we cannot implement this function.
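
A hedged sketch of the concatenating branch: `concatenate_valid` is assumed to drop `None` entries, which is what lets `conditions` default to `None` in the method above. Shapes are illustrative.

from keras import ops
from bayesflow.utils import tensor_utils

x = ops.ones((32, 4))  # noisy parameters (already scaled by 1 / sigma at the call sites below)
t = ops.ones((32, 9))  # embedded time, time_emb_dim + 1 features
xtc = tensor_utils.concatenate_valid([x, t, None], axis=-1)
assert tuple(ops.shape(xtc)) == (32, 13)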
@@ -218,7 +253,6 @@ def consistency_function(
        t: Tensor,
        conditions: Tensor = None,
        training: bool = False,
-        **kwargs,
    ) -> Tensor:
        """Compute consistency function at time t.

@@ -235,8 +269,8 @@ def consistency_function(
-        **kwargs : dict, optional, default: {}
-            Additional keyword arguments passed to the inner network.
        """
-        xtc = concatenate_valid([x / self.sigma, self.time_emb(t), conditions], axis=-1)
-        f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))
+        subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=training)
+        f = self.subnet_projector(subnet_out)
        out = ops.cos(t) * x - ops.sin(t) * self.sigma * f
        return out

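Worth noting: the cos/sin parameterization above gives the consistency boundary condition f(x, t=0) = x for free, since cos(0) = 1 and sin(0) = 0 kill the network term regardless of the subnet output. A quick numeric check, with shapes, sigma, and values assumed:

from keras import ops

t0 = ops.zeros((32, 1))
x = ops.ones((32, 4))
f = 123.0 * ops.ones((32, 4))                  # arbitrary stand-in for the projected subnet output
out = ops.cos(t0) * x - ops.sin(t0) * 1.5 * f  # sigma = 1.5, also arbitrary
assert bool(ops.all(out == x))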
@@ -273,7 +307,7 @@ def compute_metrics(
        r = 1.0  # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here

        def f_teacher(x, t):
-            o = self.subnet(concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training")
+            # x arrives pre-scaled via primals = (xt / self.sigma, t); scaling again here would double-apply 1 / sigma
+            o = self._apply_subnet(x, self.time_emb(t), conditions, training=stage == "training")
            return self.subnet_projector(o)

        primals = (xt / self.sigma, t)
@@ -287,8 +321,8 @@ def f_teacher(x, t):
        cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)

        # calculate output of the network
-        xtc = concatenate_valid([xt / self.sigma, self.time_emb(t), conditions], axis=-1)
-        student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))
+        subnet_out = self._apply_subnet(xt / self.sigma, self.time_emb(t), conditions, training=stage == "training")
+        student_out = self.subnet_projector(subnet_out)

        # calculate the tangent
        g = -(ops.cos(t) ** 2) * (self.sigma * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
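
For readers unfamiliar with the teacher pass: a Jacobian-vector product evaluates the network and its directional derivative in one forward-mode sweep, which is what makes cos_sin_dFdt cheap to obtain alongside the teacher output. A minimal JAX analogue for illustration only, not BayesFlow's own jvp wrapper:

import jax
import jax.numpy as jnp

def f(x, t):
    return jnp.cos(t) * x  # stand-in for the projected subnet output

primals = (jnp.ones((4,)), jnp.array(0.5))
tangents = (jnp.zeros((4,)), jnp.array(1.0))  # direction: vary t only
out, dfdt = jax.jvp(f, primals, tangents)     # value and directional derivative together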