import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear

from ChebnetII_pro import ChebnetII_prop
class LogReg(nn.Module):
    def __init__(self, hid_dim, n_classes):
        super(LogReg, self).__init__()
        self.fc = nn.Linear(hid_dim, n_classes)

    def forward(self, x):
        return self.fc(x)
class Discriminator(nn.Module):
    """Bilinear discriminator that scores node embeddings against a graph summary."""

    def __init__(self, dim):
        super(Discriminator, self).__init__()
        self.fn = nn.Bilinear(dim, dim, 1)

    def forward(self, h1, h2, h3, h4, c):
        # Broadcast the summary vector c to every node.
        c_x = c.expand_as(h1).contiguous()
        # Positive pairs: both filter views of the original features.
        sc_1 = self.fn(h2, c_x).squeeze(1)
        sc_2 = self.fn(h1, c_x).squeeze(1)
        # Negative pairs: both filter views of the shuffled features.
        sc_3 = self.fn(h4, c_x).squeeze(1)
        sc_4 = self.fn(h3, c_x).squeeze(1)
        logits = torch.cat((sc_1, sc_2, sc_3, sc_4))
        return logits
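# A plausible loss pairing for the concatenated scores above, assuming a
# DGI-style binary objective (hypothetical; the training loop is not part of
# this file). The first two score blocks are positives and the last two are
# negatives, so the targets are 2*n ones followed by 2*n zeros.
def contrastive_loss(logits, n_nodes):
    lbl = torch.cat((torch.ones(2 * n_nodes, device=logits.device),
                     torch.zeros(2 * n_nodes, device=logits.device)))
    return F.binary_cross_entropy_with_logits(logits, lbl)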
class Model(nn.Module):
    """Contrastive model combining high-pass and low-pass ChebNetII encodings."""

    def __init__(self, in_dim, out_dim, K, dprate, dropout, is_bns, act_fn):
        super(Model, self).__init__()
        self.encoder = ChebNetII(num_features=in_dim, hidden=out_dim, K=K,
                                 dprate=dprate, dropout=dropout,
                                 is_bns=is_bns, act_fn=act_fn)
        self.disc = Discriminator(out_dim)
        self.act_fn = nn.ReLU()
        # Learnable weights for mixing the high-pass and low-pass views.
        self.alpha = nn.Parameter(torch.tensor(0.5), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor(0.5), requires_grad=True)

    def get_embedding(self, edge_index, feat):
        h1 = self.encoder(x=feat, edge_index=edge_index, highpass=True)
        h2 = self.encoder(x=feat, edge_index=edge_index, highpass=False)
        h = torch.mul(self.alpha, h1) + torch.mul(self.beta, h2)
        return h.detach()

    def forward(self, edge_index, feat, shuf_feat):
        # Positive samples: both filter views of the original features.
        h1 = self.encoder(x=feat, edge_index=edge_index, highpass=True)
        h2 = self.encoder(x=feat, edge_index=edge_index, highpass=False)
        # Negative samples: both filter views of the shuffled features.
        h3 = self.encoder(x=shuf_feat, edge_index=edge_index, highpass=True)
        h4 = self.encoder(x=shuf_feat, edge_index=edge_index, highpass=False)
        # Mix the two views and mean-pool into a graph-level summary c.
        h = torch.mul(self.alpha, h1) + torch.mul(self.beta, h2)
        c = self.act_fn(torch.mean(h, dim=0))
        out = self.disc(h1, h2, h3, h4, c)
        return out
class ChebNetII(torch.nn.Module):
    """ChebNetII encoder: Chebyshev polynomial propagation followed by a linear layer."""

    def __init__(self, num_features, hidden=512, K=10, dprate=0.50,
                 dropout=0.50, is_bns=False, act_fn='relu'):
        super(ChebNetII, self).__init__()
        self.lin1 = Linear(num_features, hidden)
        self.prop1 = ChebnetII_prop(K=K)

        assert act_fn in ['relu', 'prelu']
        self.act_fn = nn.PReLU() if act_fn == 'prelu' else nn.ReLU()
        self.bn = torch.nn.BatchNorm1d(num_features, momentum=0.01)

        self.is_bns = is_bns
        self.dprate = dprate
        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        self.prop1.reset_parameters()
        self.lin1.reset_parameters()

    def forward(self, x, edge_index, highpass=True):
        # Propagate with the Chebyshev filter, applying input dropout if dprate > 0.
        if self.dprate == 0.0:
            x = self.prop1(x, edge_index, highpass=highpass)
        else:
            x = F.dropout(x, p=self.dprate, training=self.training)
            x = self.prop1(x, edge_index, highpass=highpass)

        x = F.dropout(x, p=self.dropout, training=self.training)
        if self.is_bns:
            x = self.bn(x)
        x = self.lin1(x)
        x = self.act_fn(x)
        return x
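# Minimal usage sketch (hypothetical sizes and random data, not part of the
# original file) showing the expected shapes end to end; it assumes
# ChebnetII_prop accepts an arbitrary edge_index as in the calls above.
if __name__ == "__main__":
    n_nodes, in_dim, out_dim = 100, 32, 64
    feat = torch.randn(n_nodes, in_dim)
    shuf_feat = feat[torch.randperm(n_nodes)]         # row-shuffled negatives
    edge_index = torch.randint(0, n_nodes, (2, 200))  # random edge list

    model = Model(in_dim, out_dim, K=10, dprate=0.5, dropout=0.5,
                  is_bns=False, act_fn='relu')
    logits = model(edge_index, feat, shuf_feat)       # shape: (4 * n_nodes,)
    loss = contrastive_loss(logits, n_nodes)
    emb = model.get_embedding(edge_index, feat)       # shape: (n_nodes, out_dim)
    print(logits.shape, loss.item(), emb.shape)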