1- import numpy as np
21import torch
2+ import math
3+ import numpy as np
34from torch .nn .parameter import Parameter
45from torch .nn import Module
6+ from graphgallery .nn .init import kaiming_uniform , zeros
57
8+ # TODO: change dtypes of trainable weights based on `floax`
69class GraphConvolution (Module ):
7- def __init__ (self , input_channels , output_channels , use_bias = False ):
10+ def __init__ (self , in_channels , out_channels , use_bias = False ):
811 super ().__init__ ()
9- self .input_channels = input_channels
10- self .output_channels = output_channels
11- self .kernel = Parameter (torch .FloatTensor ( input_channels , output_channels ))
12-
12+ self .in_channels = in_channels
13+ self .out_channels = out_channels
14+ self .kernel = Parameter (torch .Tensor ( in_channels , out_channels ))
15+
1316 if use_bias :
14- self .bias = Parameter (torch .FloatTensor ( output_channels ))
17+ self .bias = Parameter (torch .Tensor ( out_channels ))
1518 else :
1619 self .register_parameter ('bias' , None )
17-
20+
1821 self .reset_parameters ()
1922
2023 def reset_parameters (self ):
21- stdv = 1. / np .sqrt (self .kernel .size (1 ))
22- self .kernel .data .uniform_ (- stdv , stdv )
23- if self .bias is not None :
24- self .bias .data .uniform_ (- stdv , stdv )
24+ kaiming_uniform (self .kernel , a = math .sqrt (5 ))
25+ zeros (self .bias )
2526
2627 def forward (self , inputs ):
2728 x , adj = inputs
2829 h = torch .spmm (x , self .kernel )
2930 output = torch .spmm (adj , h )
30-
31+
3132 if self .bias is not None :
3233 return output + self .bias
3334 else :
3435 return output
3536
3637 def __repr__ (self ):
3738 return self .__class__ .__name__ + ' (' \
38- + str (self .input_channels ) + ' -> ' \
39- + str (self .output_channels ) + ')'
40-
41-
42- class ListModule (Module ):
43- """
44- Abstract list layer class.
45- """
46- def __init__ (self , * args ):
47- """
48- Module initializing.
49- """
50- super ().__init__ ()
51- idx = 0
52- for module in args :
53- self .add_module (str (idx ), module )
54- idx += 1
55-
56- def __getitem__ (self , idx ):
57- """
58- Getting the indexed layer.
59- """
60- if idx < 0 or idx >= len (self ._modules ):
61- raise IndexError ('index {} is out of range' .format (idx ))
62- it = iter (self ._modules .values ())
63- for i in range (idx ):
64- next (it )
65- return next (it )
66-
67- def __iter__ (self ):
68- """
69- Iterating on the layers.
70- """
71- return iter (self ._modules .values ())
72-
73- def __len__ (self ):
74- """
75- Number of layers.
76- """
77- return len (self ._modules )
39+ + str (self .in_channels ) + ' -> ' \
40+ + str (self .out_channels ) + ')'
0 commit comments