1
+ @concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
2
+ cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer}
3
+ end
4
+
5
+ function LuxCore. initialstates (rng:: AbstractRNG , r:: GNNLux.StatefulRecurrentCell )
6
+ return (cell= LuxCore. initialstates (rng, r. cell), carry= nothing )
7
+ end
8
+
9
+ function (r:: StatefulRecurrentCell )(g, x:: AbstractMatrix , ps, st:: NamedTuple )
10
+ (out, carry), st = applyrecurrentcell (r. cell, g, x, ps, st. cell, st. carry)
11
+ return out, (; cell= st, carry)
12
+ end
13
+
14
+ function (r:: StatefulRecurrentCell )(g, x:: AbstractVector , ps, st:: NamedTuple )
15
+ st, carry = st. cell, st. carry
16
+ for xᵢ in x
17
+ (out, carry), st = applyrecurrentcell (r. cell, g, xᵢ, ps, st, carry)
18
+ end
19
+ return out, (; cell= st, carry)
20
+ end
21
+
22
+ function applyrecurrentcell (l, g, x, ps, st, carry)
23
+ return Lux. apply (l, g, (x, carry), ps, st)
24
+ end
25
+
26
+ LuxCore. apply (m:: GNNContainerLayer , g, x, ps, st) = m (g, x, ps, st)
27
+
28
+ @concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)}
29
+ in_dims:: Int
30
+ out_dims:: Int
31
+ conv
32
+ gru
33
+ init_state:: Function
34
+ end
35
+
36
+ function TGCNCell (ch:: Pair{Int, Int} ; use_bias = true , init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false , use_edge_weight = true )
37
+ in_dims, out_dims = ch
38
+ conv = GCNConv (ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true )
39
+ gru = Lux. GRUCell (out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
40
+ return TGCNCell (in_dims, out_dims, conv, gru, init_state)
41
+ end
42
+
43
+ function (l:: TGCNCell )(g, (x, h), ps, st)
44
+ if h === nothing
45
+ h = l. init_state (l. out_dims, 1 )
46
+ end
47
+ x̃, stconv = l. conv (g, x, ps. conv, st. conv)
48
+ (h, (h,)), stgru = l. gru ((x̃,(h,)), ps. gru,st. gru)
49
+ return (h, h), (conv= stconv, gru= stgru)
50
+ end
51
+
52
+ LuxCore. outputsize (l:: TGCNCell ) = (l. out_dims,)
53
+ LuxCore. outputsize (l:: GNNLux.StatefulRecurrentCell ) = (l. cell. out_dims,)
54
+
55
+ function Base. show (io:: IO , tgcn:: TGCNCell )
56
+ print (io, " TGCNCell($(tgcn. in_dims) => $(tgcn. out_dims) )" )
57
+ end
58
+
59
+ TGCN (ch:: Pair{Int, Int} ; kwargs... ) = GNNLux. StatefulRecurrentCell (TGCNCell (ch; kwargs... ))
0 commit comments