Skip to content

Commit f9a2030

Browse files
committed
add BGNN-Adv and BGNN-MLP methods on bipartite graph
1 parent 1fcc17e commit f9a2030

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

dhg/models/graphs/bgnn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def train_one_layer(
7070
lbl_real = torch.ones(X_true.shape[0]).to(device)
7171
lbl_fake = torch.zeros(X_true.shape[0]).to(device)
7272

73+
netG.train(), netD.train()
7374
for _ in range(max_epoch):
7475
X_real = X_true
7576
X_fake = mp_func(netG(X_other))
@@ -209,6 +210,7 @@ def train_one_layer(
209210

210211
X_true, X_other = X_true.to(device), X_other.to(device)
211212

213+
netG.train(), netD.train()
212214
for _ in range(max_epoch):
213215
X_real = X_true
214216
X_fake = netD(mp_func(netG(X_other)))

0 commit comments

Comments
 (0)