1
- @concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
1
+ @concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
2
2
cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer}
3
3
end
4
4
5
+ function initialstates (rng:: AbstractRNG , r:: StatefulRecurrentCell )
6
+ return (cell= initialstates (rng, r. cell), carry= nothing )
7
+ end
8
+
9
+ function (r:: StatefulRecurrentCell )(g, x, ps, st:: NamedTuple )
10
+ (out, carry), st_ = applyrecurrentcell (r. cell, g, x, ps, st. cell, st. carry)
11
+ return out, (; cell= st_, carry)
12
+ end
13
+
14
+ function applyrecurrentcell (l, g, x, ps, st, carry)
15
+ return Lux. apply (l, g, (x, carry), ps, st)
16
+ end
17
+
18
+ function applyrecurrentcell (l, g, x, ps, st, :: Nothing )
19
+ return Lux. apply (l, g, x, ps, st)
20
+ end
21
+
22
+ LuxCore. apply (m:: GNNContainerLayer , g, x, ps, st) = m (g, x, ps, st)
23
+
5
24
@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)}
6
25
in_dims:: Int
7
26
out_dims:: Int
18
37
19
38
LuxCore. outputsize (l:: TGCNCell ) = (l. out_dims,)
20
39
21
- function (l:: TGCNCell )(h, g, x, ps, st)
40
+ function (l:: TGCNCell )(g, x, h , ps, st)
22
41
conv = StatefulLuxLayer {true} (l. conv, ps. conv, _getstate (st, :conv ))
23
42
gru = StatefulLuxLayer {true} (l. gru, ps. gru, _getstate (st, :gru ))
24
- m = (; conv, gru)
25
- return GNNlib. tgcn_conv (m, h, g, x)
43
+ # m = (; conv, gru)
44
+
45
+ x̃, stconv = l. conv (g, x, ps. conv, st. conv)
46
+ (h, (h,)), st = l. gru ((x̃,(h,)), ps. gru,st. gru)
47
+ return (h, (h,)), st
26
48
end
27
49
28
50
function Base. show (io:: IO , tgcn:: TGCNCell )
29
51
print (io, " TGCNCell($(tgcn. in_dims) => $(tgcn. out_dims) )" )
30
- end
52
+ end
53
+
54
+ tgcn = StatefulRecurrentCell (TGCNCell (1 => 3 ))
0 commit comments