|
9 | 9 | from . import utils |
10 | 10 | from ... import function as fn |
11 | 11 |
|
12 | | -__all__ = ['GraphConv', 'RelGraphConv'] |
| 12 | +__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv'] |
13 | 13 |
|
14 | 14 | class GraphConv(gluon.Block): |
15 | 15 | r"""Apply graph convolution over an input signal. |
@@ -74,7 +74,7 @@ def __init__(self, |
74 | 74 |
|
75 | 75 | with self.name_scope(): |
76 | 76 | self.weight = self.params.get('weight', shape=(in_feats, out_feats), |
77 | | - init=mx.init.Xavier()) |
| 77 | + init=mx.init.Xavier(magnitude=math.sqrt(2.0))) |
78 | 78 | if bias: |
79 | 79 | self.bias = self.params.get('bias', shape=(out_feats,), |
80 | 80 | init=mx.init.Zero()) |
@@ -108,7 +108,7 @@ def forward(self, graph, feat): |
108 | 108 | graph = graph.local_var() |
109 | 109 | if self._norm: |
110 | 110 | degs = graph.in_degrees().astype('float32') |
111 | | - norm = mx.nd.power(degs, -0.5) |
| 111 | + norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5) |
112 | 112 | shp = norm.shape + (1,) * (feat.ndim - 1) |
113 | 113 | norm = norm.reshape(shp).as_in_context(feat.context) |
114 | 114 | feat = feat * norm |
@@ -147,6 +147,101 @@ def __repr__(self): |
147 | 147 | summary += '\n)' |
148 | 148 | return summary |
149 | 149 |
|
class TAGConv(gluon.Block):
    r"""Apply Topology Adaptive Graph Convolutional Network

    .. math::
        \mathbf{X}^{\prime} = \sum_{k=0}^K \mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2}\mathbf{X} \mathbf{\Theta}_{k},

    where :math:`\mathbf{A}` denotes the adjacency matrix and
    :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix.

    Parameters
    ----------
    in_feats : int
        Number of input features.
    out_feats : int
        Number of output features.
    k: int, optional
        Number of hops :math: `k`. (default: 2)
    bias: bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation: callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Attributes
    ----------
    lin : mxnet.gluon.parameter.Parameter
        The learnable weight tensor.
    h_bias : mxnet.gluon.parameter.Parameter
        The learnable bias tensor (only present when ``bias=True``).
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 k=2,
                 bias=True,
                 activation=None):
        super(TAGConv, self).__init__()
        self.out_feats = out_feats
        self.k = k
        # NOTE: self.bias holds the boolean flag, not the parameter; the
        # parameter itself is self.h_bias and only exists when bias=True.
        self.bias = bias
        self.activation = activation
        self.in_feats = in_feats

        # One stacked projection for the concatenated [X, A'X, ..., A'^k X],
        # hence the (k + 1) factor on the input dimension.
        self.lin = self.params.get(
            'weight', shape=(self.in_feats * (self.k + 1), self.out_feats),
            init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
        if self.bias:
            self.h_bias = self.params.get('bias', shape=(out_feats,),
                                          init=mx.init.Zero())

    def forward(self, graph, feat):
        r"""Compute graph convolution

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is size of input feature, :math:`N` is the number of nodes.

        Returns
        -------
        mxnet.NDArray
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """
        graph = graph.local_var()

        # Symmetric normalization D^{-1/2}; degrees are clipped to >= 1 so
        # isolated nodes do not produce inf from 0^{-1/2}.
        degs = graph.in_degrees().astype('float32')
        norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
        shp = norm.shape + (1,) * (feat.ndim - 1)
        norm = norm.reshape(shp).as_in_context(feat.context)

        # Accumulate the k-hop propagated features by concatenation:
        # feat grows to [X, A'X, ..., A'^k X] along the last axis.
        rst = feat
        for _ in range(self.k):
            rst = rst * norm
            graph.ndata['h'] = rst

            graph.update_all(fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'))
            rst = graph.ndata['h']
            rst = rst * norm
            feat = mx.nd.concat(feat, rst, dim=-1)

        rst = mx.nd.dot(feat, self.lin.data(feat.context))
        # BUGFIX: the original tested `self.bias is not None`, which is always
        # true for a bool flag; with bias=False self.h_bias does not exist and
        # forward raised AttributeError. Truth-test the flag instead.
        if self.bias:
            rst = rst + self.h_bias.data(rst.context)

        if self.activation is not None:
            rst = self.activation(rst)

        return rst
| 244 | + |
150 | 245 | class RelGraphConv(gluon.Block): |
151 | 246 | r"""Relational graph convolution layer. |
152 | 247 |
|
|
0 commit comments