We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1fcc17e commit f9a2030Copy full SHA for f9a2030
dhg/models/graphs/bgnn.py
@@ -70,6 +70,7 @@ def train_one_layer(
70
lbl_real = torch.ones(X_true.shape[0]).to(device)
71
lbl_fake = torch.zeros(X_true.shape[0]).to(device)
72
73
+ netG.train(), netD.train()
74
for _ in range(max_epoch):
75
X_real = X_true
76
X_fake = mp_func(netG(X_other))
@@ -209,6 +210,7 @@ def train_one_layer(
209
210
211
X_true, X_other = X_true.to(device), X_other.to(device)
212
213
214
215
216
X_fake = netD(mp_func(netG(X_other)))
0 commit comments