@@ -43,15 +43,9 @@ def __init__(
4343 self .register_buffer ("_supports" , supports )
4444
4545 num_matrices = len (supports ) * self ._max_diffusion_step + 1
46- input_size_gconv = (self ._num_units + input_dim ) * num_matrices
47-
48- if self ._use_gc_for_ru :
49- input_size_ru = input_size_gconv
50- else :
51- input_size_ru = self ._num_units + input_dim
52- raise NotImplementedError (
53- "Fully-connected reset and update gates not yet implemented"
54- )
46+ input_size_fc = self ._num_units + input_dim
47+ input_size_gconv = input_size_fc * num_matrices
48+ input_size_ru = input_size_gconv if self ._use_gc_for_ru else input_size_fc
5549
5650 output_size = 2 * self ._num_units
5751 self ._ru_weights = nn .Parameter (torch .empty (input_size_ru , output_size ))
@@ -85,22 +79,18 @@ def _fc(self, inputs, state, output_size, bias_start=0.0, reset=True):
8579 shape = (batch_size * self ._num_nodes , - 1 )
8680 inputs = torch .reshape (inputs , shape )
8781 state = torch .reshape (state , shape )
88- inputs_and_state = torch .cat ([inputs , state ], dim = - 1 )
89-
90- value = torch .sigmoid (torch .matmul (inputs_and_state , self ._ru_weights ))
91- value += self ._ru_biases
82+ x = torch .cat ([inputs , state ], dim = - 1 )
9283
93- return value
84+ return torch . matmul ( x , self . _ru_weights ) + self . _ru_biases
9485
9586 def _gconv (self , inputs , state , output_size , bias_start = 0.0 , reset = False ):
9687 batch_size = inputs .shape [0 ]
9788 shape = (batch_size , self ._num_nodes , - 1 )
9889 inputs = torch .reshape (inputs , shape )
9990 state = torch .reshape (state , shape )
100- inputs_and_state = torch .cat ([inputs , state ], dim = 2 )
101- input_size = inputs_and_state .size (2 )
91+ x = torch .cat ([inputs , state ], dim = 2 )
92+ input_size = x .size (2 )
10293
103- x = inputs_and_state
10494 x0 = x .permute (1 , 2 , 0 )
10595 x0 = torch .reshape (x0 , shape = [self ._num_nodes , input_size * batch_size ])
10696 x = torch .unsqueeze (x0 , 0 )
0 commit comments