Skip to content

Commit c95e749

Browse files
committed
fix linkpred models
1 parent f391c19 commit c95e749

File tree

5 files changed

+9
-8
lines changed

5 files changed

+9
-8
lines changed

graphgallery/gallery/linkpred/pyg/gae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def model_step(self,
3535
lr=0.01,
3636
bias=False):
3737

38-
model = get_model("autoencoder.VGAE", self.backend)
38+
model = get_model("autoencoder.GAE", self.backend)
3939
model = model(self.graph.num_node_attrs,
4040
out_features=out_features,
4141
hids=hids,

graphgallery/nn/models/pyg/autoencoder/autoencoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def train_step_on_batch(self,
2020
self.train()
2121
optimizer = self.optimizer
2222
optimizer.zero_grad()
23-
x = to_device(x, device=device)
23+
x, _ = to_device(x, device=device)
2424
z = self.encode(*x)
2525
# here `out_index` maybe pos_edge_index
2626
# or (pos_edge_index, neg_edge_index)
@@ -65,7 +65,7 @@ def test_step_on_batch(self,
6565
device="cpu"):
6666
self.eval()
6767
metrics = self.metrics
68-
x = to_device(x, device=device)
68+
x, _ = to_device(x, device=device)
6969
z = self.encode(*x)
7070
pred = self.decode(z, out_index)
7171

@@ -78,7 +78,7 @@ def test_step_on_batch(self,
7878
@torch.no_grad()
7979
def predict_step_on_batch(self, x, out_index=None, device="cpu"):
8080
self.eval()
81-
x = to_device(x, device=device)
81+
x, _ = to_device(x, device=device)
8282
z = self.encode(*x)
8383
pred = self.decode(z, out_index)
8484
return pred.cpu().detach()

graphgallery/nn/models/pytorch/autoencoder/autoencoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_step_on_batch(self,
2727
x, y = to_device(x, y, device=device)
2828
z = self.encode(*x)
2929
out = self.decode(z, out_index)
30-
loss = self.compute_loss(out, y)
30+
loss, out = self.compute_loss(out, y)
3131
self.update_metrics(out, y)
3232

3333
if loss is not None:

graphgallery/nn/models/pytorch/autoencoder/vgae.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ def forward(self, x, adj):
6363
out = self.decode(z)
6464
return out
6565

66-
def compute_loss(self, out, y):
66+
def compute_loss(self, out, y, out_index=None):
67+
out = self.index_select(out, out_index=out_index)
6768
if self.training:
6869
mu = self.cache.pop('mu')
6970
logstd = self.cache.pop('logstd')
7071
kl_loss = -0.5 / mu.size(0) * torch.mean(torch.sum(1 + 2 * logstd - mu.pow(2) - logstd.exp().pow(2), dim=1))
7172
else:
7273
kl_loss = 0.
73-
return self.loss(out, y) + kl_loss
74+
return self.loss(out, y) + kl_loss, out

graphgallery/nn/models/torch_keras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_step_on_batch(self,
115115
@torch.no_grad()
116116
def predict_step_on_batch(self, x, out_index=None, device="cpu"):
117117
self.eval()
118-
x = to_device(x, device=device)
118+
x, _ = to_device(x, device=device)
119119
out = self.index_select(self(*x), out_index=out_index)
120120
return out.cpu().detach()
121121

0 commit comments

Comments
 (0)