@@ -748,3 +748,51 @@ julia> size(y[end]) # (d_out, num_nodes[end])
748748``` 
749749""" 
750750EvolveGCNO (args... ; kws... ) =  GNNRecurrence (EvolveGCNOCell (args... ; kws... ))
751+ 
752+ 
753+ 
754+ @concrete  struct  TGCNCell <:  GNNLayer 
755+     in:: Int 
756+     out:: Int 
757+     conv_z
758+     dense_z
759+     conv_r
760+     dense_r
761+     conv_h
762+     dense_h
763+ end 
764+ 
765+ Flux. @layer  :noexpand  TGCNCell
766+ 
767+ function  TGCNCell ((in, out):: Pair{Int, Int} ; kws... )
768+     conv_z =  GNNChain (GCNConv (in =>  out, relu; kws... ), GCNConv (out =>  out; kws... ))
769+     dense_z =  Dense (2 * out =>  out, sigmoid)
770+     conv_r =  GNNChain (GCNConv (in =>  out, relu; kws... ), GCNConv (out =>  out; kws... ))
771+     dense_r =  Dense (2 * out =>  out, sigmoid)
772+     conv_h =  GNNChain (GCNConv (in =>  out, relu; kws... ), GCNConv (out =>  out; kws... ))
773+     dense_h =  Dense (2 * out =>  out, tanh)
774+     return  TGCNCell (in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
775+ end 
776+ 
777+ Flux. initialstates (cell:: TGCNCell ) =  zeros_like (cell. dense_z. weight, cell. out)
778+ 
779+ (cell:: TGCNCell )(g:: GNNGraph , x:: AbstractMatrix ) =  cell (g, x, initialstates (cell))
780+ 
781+ function  (cell:: TGCNCell )(g:: GNNGraph , x:: AbstractMatrix , h:: AbstractVector )
782+     return  cell (g, x, repeat (h, 1 , g. num_nodes))
783+ end 
784+ 
785+ function  (cell:: TGCNCell )(g:: GNNGraph , x:: AbstractMatrix , h:: AbstractMatrix )
786+     z =  cell. conv_z (g, x)
787+     z =  cell. dense_z (vcat (z, h))
788+     r =  cell. conv_r (g, x)
789+     r =  cell. dense_r (vcat (r, h))
790+     h̃ =  cell. conv_h (g, x)
791+     h̃ =  cell. dense_h (vcat (h̃, r .*  h))
792+     h =  (1  .-  z) .*  h .+  z .*  h̃
793+     return  h, h
794+ end 
795+ 
796+ function  Base. show (io:: IO , cell:: TGCNCell )
797+     print (io, " TGCNCell($(cell. in)  => $(cell. out) )" 
798+ end 
0 commit comments