@@ -116,6 +116,7 @@ def __init__(
         if subnet == "mlp":
             subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
         self.subnet = find_network(subnet, **subnet_kwargs)
+        self._concatenate_subnet_input = subnet_kwargs.get("concatenate_subnet_input", True)

         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")

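The flag is looked up with a default of `True`, so configurations that never mention it keep the current concatenated-input behaviour; only an explicit `False` switches the subnet over to separate inputs. A hypothetical construction sketch follows (the import path, the constructor defaults, and the tuple-capable custom subnet are assumptions; such a subnet is sketched after the `prepare_subnet_input` hunk further below):

# Hypothetical usage sketch, not part of this diff; the import path may differ by version.
from bayesflow.networks import DiffusionModel

# Default behaviour: the key is absent, so the subnet receives one tensor
# concatenated along the last axis, exactly as before this change.
concat_model = DiffusionModel(subnet="mlp")

# Opting out: the subnet must then handle (xz, log_snr, conditions) itself,
# e.g. a custom network like the MyTupleSubnet sketched further below.
# split_model = DiffusionModel(subnet=MyTupleSubnet, subnet_kwargs={"concatenate_subnet_input": False})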
@@ -149,6 +150,7 @@ def get_config(self):
             "prediction_type": self._prediction_type,
             "loss_type": self._loss_type,
             "integrate_kwargs": self.integrate_kwargs,
+            "_concatenate_subnet_input": self._concatenate_subnet_input,
         }
         return base_config | serialize(config)

@@ -197,6 +199,33 @@ def convert_prediction_to_x(
             return (z + sigma_t**2 * pred) / alpha_t
         raise ValueError(f"Unknown prediction type {self._prediction_type}.")

+    def prepare_subnet_input(self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None) -> Tensor:
+        """
+        Prepares the input for the subnet either by concatenating the latent variable `xz`,
+        the log signal-to-noise ratio `log_snr`, and optional conditions or by returning them separately.
+
+        Parameters
+        ----------
+        xz : Tensor
+            The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
+        log_snr : Tensor
+            The log signal-to-noise ratio tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g. parameters).
+
+        Returns
+        -------
+        Tensor
+            The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
+        """
+        if self._concatenate_subnet_input:
+            if conditions is None:
+                return tensor_utils.concatenate_valid([xz, log_snr], axis=-1)
+            else:
+                return tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
+        else:
+            return xz, log_snr, conditions
+
     def velocity(
         self,
         xz: Tensor,
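When `concatenate_subnet_input` is disabled, `prepare_subnet_input` returns the triple unchanged, yet the model still calls `self.subnet(xtc, training=training)`, so the subnet's `call` receives the tuple `(xz, log_snr, conditions)` as its single positional argument and has to unpack it itself. Below is a minimal sketch of such a subnet; the class name, layer sizes, and the explicit `concatenate_subnet_input` constructor argument (which would arrive via the forwarded `subnet_kwargs`) are illustrative assumptions, not part of this diff.

import keras


class MyTupleSubnet(keras.layers.Layer):
    """Illustrative subnet that consumes xz, log_snr, and conditions separately."""

    def __init__(self, width: int = 128, concatenate_subnet_input: bool = False, **kwargs):
        super().__init__(**kwargs)
        # subnet_kwargs (including the flag) is forwarded to find_network above,
        # so this sketch accepts and ignores the keyword rather than erroring on it.
        self.embed_x = keras.layers.Dense(width)
        self.embed_snr = keras.layers.Dense(width)
        self.embed_cond = keras.layers.Dense(width)
        self.hidden = keras.layers.Dense(width, activation="gelu")

    def call(self, inputs, training=False):
        xz, log_snr, conditions = inputs  # tuple returned by prepare_subnet_input
        h = self.embed_x(xz) + self.embed_snr(log_snr)
        if conditions is not None:
            h = h + self.embed_cond(conditions)
        return self.hidden(h)  # output_projector maps this to the target dimension

Whether such a class can be passed directly as `subnet` or has to be registered first is up to `find_network`, which this diff does not change.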
@@ -221,7 +250,7 @@ def velocity(
             If True, computes the velocity for the stochastic formulation (SDE).
             If False, uses the deterministic formulation (ODE).
         conditions : Tensor, optional
-            Optional conditional inputs to the network, such as conditioning variables
+            Conditional inputs to the network, such as conditioning variables
             or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
         training : bool, optional
             Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,11 +267,7 @@ def velocity(
         log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

-        if conditions is None:
-            xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1)
-        else:
-            xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
-
+        xtc = self.prepare_subnet_input(xz, self._transform_log_snr(log_snr_t), conditions=conditions)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)

         x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
@@ -461,10 +486,7 @@ def compute_metrics(
         diffused_x = alpha_t * x + sigma_t * eps_t

         # calculate output of the network
-        if conditions is None:
-            xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
-        else:
-            xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
+        xtc = self.prepare_subnet_input(diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)

         x_pred = self.convert_prediction_to_x(
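For the default concatenated path used by both `velocity` and `compute_metrics`, the subnet input is simply the pieces stacked along the last axis. The sketch below illustrates the resulting shape with made-up dimensions, using plain `keras.ops.concatenate` in place of `tensor_utils.concatenate_valid`:

import keras

xz = keras.ops.zeros((64, 4))          # noisy sample, D = 4 (illustrative)
log_snr = keras.ops.zeros((64, 1))     # log SNR broadcast to (..., 1)
conditions = keras.ops.zeros((64, 3))  # optional conditioning variables

xtc = keras.ops.concatenate([xz, log_snr, conditions], axis=-1)
print(keras.ops.shape(xtc))            # (64, 8): fed to the subnet, then output_projector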