@@ -183,21 +183,29 @@ def forward(self, targets, condition, **kwargs):
183183 # Add noise to target if using SoftFlow, use explicitly
184184 # not in call(), since methods are public
185185 if self .soft_flow and condition is not None :
186+ # Extract shapes of tensors
187+ target_shape = tf .shape (targets )
188+ condition_shape = tf .shape (condition )
189+
186190 # Needs to be concatinable with condition
187- shape_scale = (
188- (condition .shape [0 ], 1 ) if len (condition .shape ) == 2 else (condition .shape [0 ], condition .shape [1 ], 1 )
189- )
191+ if len (condition_shape ) == 2 :
192+ shape_scale = (condition_shape [0 ], 1 )
193+ else :
194+ shape_scale = (condition_shape [0 ], condition_shape [1 ], 1 )
195+
190196 # Case training mode
191197 if kwargs .get ("training" ):
192198 noise_scale = tf .random .uniform (shape = shape_scale , minval = self .soft_low , maxval = self .soft_high )
193199 # Case inference mode
194200 else :
195201 noise_scale = tf .zeros (shape = shape_scale ) + self .soft_low
202+
196203 # Perturb data with noise (will broadcast to all dimensions)
197- if len (shape_scale ) == 2 and len (targets . shape ) == 3 :
198- targets += tf .expand_dims (noise_scale , axis = 1 ) * tf .random .normal (shape = targets . shape )
204+ if len (shape_scale ) == 2 and len (target_shape ) == 3 :
205+ targets += tf .expand_dims (noise_scale , axis = 1 ) * tf .random .normal (shape = target_shape )
199206 else :
200- targets += noise_scale * tf .random .normal (shape = targets .shape )
207+ targets += noise_scale * tf .random .normal (shape = target_shape )
208+
201209 # Augment condition with noise scale variate
202210 condition = tf .concat ((condition , noise_scale ), axis = - 1 )
203211
0 commit comments