Skip to content

Commit 23d9e45

Browse files
TGCNCell
1 parent b229ab2 commit 23d9e45

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,3 +748,51 @@ julia> size(y[end]) # (d_out, num_nodes[end])
748748
```
749749
"""
750750
EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))
751+
752+
753+
754+
@concrete struct TGCNCell <: GNNLayer
755+
in::Int
756+
out::Int
757+
conv_z
758+
dense_z
759+
conv_r
760+
dense_r
761+
conv_h
762+
dense_h
763+
end
764+
765+
Flux.@layer :noexpand TGCNCell
766+
767+
function TGCNCell((in, out)::Pair{Int, Int}; kws...)
768+
conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
769+
dense_z = Dense(2*out => out, sigmoid)
770+
conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
771+
dense_r = Dense(2*out => out, sigmoid)
772+
conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
773+
dense_h = Dense(2*out => out, tanh)
774+
return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
775+
end
776+
777+
Flux.initialstates(cell::TGCNCell) = zeros_like(cell.dense_z.weight, cell.out)
778+
779+
(cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))
780+
781+
function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
782+
return cell(g, x, repeat(h, 1, g.num_nodes))
783+
end
784+
785+
function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
786+
z = cell.conv_z(g, x)
787+
z = cell.dense_z(vcat(z, h))
788+
r = cell.conv_r(g, x)
789+
r = cell.dense_r(vcat(r, h))
790+
= cell.conv_h(g, x)
791+
= cell.dense_h(vcat(h̃, r .* h))
792+
h = (1 .- z) .* h .+ z .*
793+
return h, h
794+
end
795+
796+
function Base.show(io::IO, cell::TGCNCell)
797+
print(io, "TGCNCell($(cell.in) => $(cell.out))")
798+
end

GraphNeuralNetworks/src/layers/temporalconv_old.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@ end
77

88
Flux.@layer :noexpand TGCNCell
99

10-
function TGCNCell(ch::Pair{Int, Int};
10+
function TGCNCell((in, out)::Pair{Int, Int};
1111
bias::Bool = true,
1212
init = Flux.glorot_uniform,
1313
add_self_loops = false)
14-
in, out = ch
1514
conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops)
1615
gru = GRUCell(out => out)
1716
return TGCNCell(conv, gru, in, out)

0 commit comments

Comments
 (0)