@@ -228,7 +228,6 @@ def run_SAM(df_data, skeleton=None, device=None, **kwargs):
228228 activation_function = th .nn .LeakyReLU ,
229229 activation_argument = 0.2 , ** kwargs )
230230 kwargs ["activation_function" ] = activation_function
231-
232231 sam = sam .to (device )
233232 discriminator_sam = discriminator_sam .to (device )
234233 data = data .to (device )
@@ -269,20 +268,27 @@ def run_SAM(df_data, skeleton=None, device=None, **kwargs):
269268 # 1. Train discriminator on fake
270269 disc_output_detached = discriminator_sam (
271270 generator_output .detach ())
272- disc_output = discriminator_sam (generator_output )
273271 disc_losses .append (
274272 criterion (disc_output_detached , false_variable ))
275273
276- # 2. Train the generator :
277- gen_losses .append (criterion (disc_output , true_variable ))
278-
279274 true_output = discriminator_sam (batch )
280275 adv_loss = sum (disc_losses )/ cols + \
281276 criterion (true_output , true_variable )
282277 gen_loss = sum (gen_losses )
283278
284279 adv_loss .backward ()
285280 d_optimizer .step ()
281+ g_optimizer .zero_grad ()
282+
283+ for i in range (cols ):
284+ generator_output = th .cat ([v for c in [batch_vectors [: i ], [
285+ generated_variables [i ]],
286+ batch_vectors [i + 1 :]] for v in c ], 1 )
287+ # 1. Train discriminator on fake
288+ disc_output = discriminator_sam (generator_output )
289+
290+ # 2. Train the generator :
291+ gen_losses .append (criterion (disc_output , true_variable ))
286292
287293 # 3. Compute filter regularization
288294 filters = th .stack (
0 commit comments