Skip to content

Commit 5f8235c

Browse files
committed
[FIX] SamV1 fix for gradient computation
1 parent 4c7b715 commit 5f8235c

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

cdt/causality/graph/SAMv1.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,6 @@ def run_SAM(df_data, skeleton=None, device=None, **kwargs):
228228
activation_function=th.nn.LeakyReLU,
229229
activation_argument=0.2, **kwargs)
230230
kwargs["activation_function"] = activation_function
231-
232231
sam = sam.to(device)
233232
discriminator_sam = discriminator_sam.to(device)
234233
data = data.to(device)
@@ -269,20 +268,27 @@ def run_SAM(df_data, skeleton=None, device=None, **kwargs):
269268
# 1. Train discriminator on fake
270269
disc_output_detached = discriminator_sam(
271270
generator_output.detach())
272-
disc_output = discriminator_sam(generator_output)
273271
disc_losses.append(
274272
criterion(disc_output_detached, false_variable))
275273

276-
# 2. Train the generator :
277-
gen_losses.append(criterion(disc_output, true_variable))
278-
279274
true_output = discriminator_sam(batch)
280275
adv_loss = sum(disc_losses)/cols + \
281276
criterion(true_output, true_variable)
282277
gen_loss = sum(gen_losses)
283278

284279
adv_loss.backward()
285280
d_optimizer.step()
281+
g_optimizer.zero_grad()
282+
283+
for i in range(cols):
284+
generator_output = th.cat([v for c in [batch_vectors[: i], [
285+
generated_variables[i]],
286+
batch_vectors[i + 1:]] for v in c], 1)
287+
# 1. Train discriminator on fake
288+
disc_output = discriminator_sam(generator_output)
289+
290+
# 2. Train the generator :
291+
gen_losses.append(criterion(disc_output, true_variable))
286292

287293
# 3. Compute filter regularization
288294
filters = th.stack(

tests/scripts/test_causality_graph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def test_SAMv1():
6262

6363
if __name__ == "__main__":
6464
test_SAM()
65+
test_SAMv1()
6566
# test_directed()
6667
# test_undirected()
6768
# test_graph()

0 commit comments

Comments
 (0)