@@ -116,7 +116,7 @@ def __init__(
116116 if subnet == "mlp" :
117117 subnet_kwargs = DiffusionModel .MLP_DEFAULT_CONFIG | subnet_kwargs
118118 self .subnet = find_network (subnet , ** subnet_kwargs )
119- self ._concatenate_subnet_input = subnet_kwargs .get ("concatenate_subnet_input " , True )
119+ self ._subnet_concatenated_input = subnet_kwargs .get ("concatenated_input " , True )
120120
121121 self .output_projector = keras .layers .Dense (units = None , bias_initializer = "zeros" , name = "output_projector" )
122122
@@ -150,7 +150,7 @@ def get_config(self):
150150 "prediction_type" : self ._prediction_type ,
151151 "loss_type" : self ._loss_type ,
152152 "integrate_kwargs" : self .integrate_kwargs ,
153- "_concatenate_subnet_input " : self ._concatenate_subnet_input ,
153+ "subnet_concatenated_input " : self ._subnet_concatenated_input ,
154154 }
155155 return base_config | serialize (config )
156156
@@ -218,7 +218,7 @@ def prepare_subnet_input(self, xz: Tensor, log_snr: Tensor, conditions: Tensor =
218218 Tensor
219219 The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
220220 """
221- if self ._concatenate_subnet_input :
221+ if self ._subnet_concatenated_input :
222222 if conditions is None :
223223 return tensor_utils .concatenate_valid ([xz , log_snr ], axis = - 1 )
224224 else :
0 commit comments