Skip to content

Commit 7d7c641

Browse files
authored
Update tune.py
1 parent 1a440ee commit 7d7c641

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

tune.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,22 @@ def tune(cfg, graph, model, task_heads, opt, W_down, dis, cluster_id, pos_idx, n
8080

8181
with torch.no_grad():
8282
eval_h = model(feats)
83-
eval_center = eval_h.mean(dim=0)
84-
score = -dis(eval_h, eval_center, mode='global').cpu().numpy()
83+
eval_h_local = torch.mul(eval_h, head1)
84+
eval_h_cluster = torch.mul(H_init, head2)
85+
eval_h_local_diff = mean_agg.extract_H_diff(graph, eval_h_local, cluster_id, mode='local')
86+
eval_h_cluster_diff = mean_agg.extract_H_diff(graph, eval_h_cluster, cluster_id, mode='cluster')
87+
88+
eval_h_concat = torch.cat((eval_h_local_diff, eval_h_cluster_diff), dim=1)
89+
90+
eval_h_down = torch.matmul(eval_h_concat, W_down)
91+
92+
eval_center = eval_h_down.mean(dim=0)
93+
score = dis(eval_h_down, eval_center, mode='global').cpu().numpy()
8594

8695
aucroc = roc_auc_score(labels.cpu().numpy(), score)
8796
aucpr = average_precision_score(labels.cpu().numpy(), score)
8897

98+
8999
if aucroc > best_aucroc:
90100
best_aucroc = aucroc
91101
best_aucpr = aucpr

0 commit comments

Comments
 (0)