Skip to content

Commit 0933818

Browse files
committed
Improvement
1 parent dc39f81 commit 0933818

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

GNNLux/src/layers/temporalconv.jl

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,26 @@
1-
@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
1+
@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
22
cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer}
33
end
44

5+
function initialstates(rng::AbstractRNG, r::StatefulRecurrentCell)
6+
return (cell=initialstates(rng, r.cell), carry=nothing)
7+
end
8+
9+
function (r::StatefulRecurrentCell)(g, x, ps, st::NamedTuple)
10+
(out, carry), st_ = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry)
11+
return out, (; cell=st_, carry)
12+
end
13+
14+
function applyrecurrentcell(l, g, x, ps, st, carry)
15+
return Lux.apply(l, g, (x, carry), ps, st)
16+
end
17+
18+
function applyrecurrentcell(l, g, x, ps, st, ::Nothing)
19+
return Lux.apply(l, g, x, ps, st)
20+
end
21+
22+
LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
23+
524
@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)}
625
in_dims::Int
726
out_dims::Int
@@ -18,13 +37,18 @@ end
1837

1938
LuxCore.outputsize(l::TGCNCell) = (l.out_dims,)
2039

21-
function (l::TGCNCell)(h, g, x, ps, st)
40+
function (l::TGCNCell)(g, x, h, ps, st)
2241
conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv))
2342
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
24-
m = (; conv, gru)
25-
return GNNlib.tgcn_conv(m, h, g, x)
43+
#m = (; conv, gru)
44+
45+
x̃, stconv = l.conv(g, x, ps.conv, st.conv)
46+
(h, (h,)), st = l.gru((x̃,(h,)), ps.gru,st.gru)
47+
return (h, (h,)), st
2648
end
2749

2850
function Base.show(io::IO, tgcn::TGCNCell)
2951
print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))")
30-
end
52+
end
53+
54+
tgcn = StatefulRecurrentCell(TGCNCell(1 =>3))

0 commit comments

Comments
 (0)