Skip to content

Commit fbbe84d

Browse files
committed
First draft
1 parent fc67808 commit fbbe84d

File tree

5 files changed

+40
-5
lines changed

5 files changed

+40
-5
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ export AGNNConv,
3434
# SGConv,
3535
# TAGConv,
3636
# TransformerConv
37-
37+
38+
include("layers/temporalconv.jl")
39+
export TGCNCell
3840

3941
end #module
4042

GNNLux/src/layers/temporalconv.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)}
2+
in_dims::Int
3+
out_dims::Int
4+
conv
5+
gru
6+
end
7+
8+
function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
9+
in_dims, out_dims = ch
10+
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true)
11+
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
12+
return TGCNCell(in_dims, out_dims, conv, gru)
13+
end
14+
15+
LuxCore.outputsize(l::TGCNCell) = (l.out_dims,)
16+
17+
function (l::TGCNCell)(h, g, x)
18+
conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv))
19+
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
20+
m = (; conv, gru)
21+
return GNNlib.tgcn_conv(m, h, g, x)
22+
end
23+
24+
function Base.show(io::IO, tgcn::TGCNCell)
25+
print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))")
26+
end

GNNlib/src/GNNlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ export agnn_conv,
6161
transformer_conv
6262

6363
include("layers/temporalconv.jl")
64-
export a3tgcn_conv
64+
export tgcn_conv
6565

6666
include("layers/pool.jl")
6767
export global_pool,

GNNlib/src/layers/temporalconv.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
####################### TGCN ######################################
2+
3+
function tgcn_conv(l, h, g::GNNGraph, x::AbstractArray)
4+
= l.conv(g, x)
5+
h, x̃ = l.gru(h, x̃)
6+
return h, x̃
7+
end
8+
9+
110
function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray)
211
h = a3tgcn.tgcn(g, x)
312
e = a3tgcn.dense1(h)

src/layers/temporalconv.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ function TGCNCell(ch::Pair{Int, Int};
3535
end
3636

3737
function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray)
38-
= tgcn.conv(g, x)
39-
h, x̃ = tgcn.gru(h, x̃)
40-
return h, x̃
38+
return GNNlib.tgcn_conv(tgcn, h, g, x)
4139
end
4240

4341
function Base.show(io::IO, tgcn::TGCNCell)

0 commit comments

Comments
 (0)