-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
56 lines (44 loc) · 1.81 KB
/
model.py
File metadata and controls
56 lines (44 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class LabelPropagation(nn.Module):
    r"""
    Description
    -----------
    Introduced in `Learning from Labeled and Unlabeled Data with Label Propagation <https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3864&rep=rep1&type=pdf>`_

    .. math::
        \mathbf{Y}^{(t+1)} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2} \mathbf{Y}^{(t)} + (1 - \alpha) \mathbf{Y}^{(0)},

    where :math:`\mathbf{Y}^{(0)}` is the initial (seed) label signal and
    unlabeled data is inferred from labeled data via propagation. The residual
    term is anchored to :math:`\mathbf{Y}^{(0)}` so the seed labels keep
    contributing at every step instead of being diffused away.

    Parameters
    ----------
    num_layers: int
        The number of propagations.
    alpha: float
        The :math:`\alpha` coefficient.
    """

    def __init__(self, num_layers, alpha):
        super(LabelPropagation, self).__init__()
        self.num_layers = num_layers
        self.alpha = alpha

    @torch.no_grad()
    def forward(self, g, labels, mask=None, post_step=lambda y: y.clamp_(0., 1.)):
        """Propagate ``labels`` over ``g`` for ``self.num_layers`` steps.

        Parameters
        ----------
        g : graph
            Graph to propagate over; must provide ``local_scope``,
            ``in_degrees``, ``ndata`` and ``update_all`` (presumably a
            ``dgl.DGLGraph`` -- confirm against callers).
        labels : torch.Tensor
            Either integer class indices (``torch.long``), which are flattened
            and converted to a float32 one-hot matrix, or an already-encoded
            float tensor (assumed ``(num_nodes, num_classes)`` -- TODO confirm).
        mask : tensor, optional
            Index or boolean selector of the seed nodes. When given, only those
            rows of ``labels`` form the initial signal; all other rows start at
            zero. When ``None``, every row of ``labels`` is used.
        post_step : callable
            Applied to the result of every propagation step; its return value
            becomes the new signal. The default clamps values in place to
            ``[0, 1]``.

        Returns
        -------
        torch.Tensor
            The propagated label matrix. When ``num_layers`` is 0 the initial
            signal is returned unchanged (``post_step`` is not applied).
        """
        with g.local_scope():
            if labels.dtype == torch.long:
                # reshape (unlike view) also accepts non-contiguous index tensors.
                labels = F.one_hot(labels.reshape(-1)).to(torch.float32)

            y = labels
            if mask is not None:
                y = torch.zeros_like(labels)
                y[mask] = labels[mask]

            # Residual term (1 - alpha) * Y^(0): computed once and kept fixed so
            # every step stays anchored to the seed labels.
            anchor = (1 - self.alpha) * y
            # clamp(min=1) keeps the norm finite for nodes with no in-edges.
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(labels.device).unsqueeze(1)

            for _ in range(self.num_layers):
                # Assume the graphs to be undirected: in-degrees are used for
                # both the source- and destination-side normalization, which
                # equals D^-1/2 A D^-1/2 only when in-degree == out-degree.
                g.ndata['h'] = y * norm
                g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                y = anchor + self.alpha * g.ndata.pop('h') * norm
                y = post_step(y)

            return y