Skip to content

Commit d3cbab9

Browse files
refacto forward in GNNlib
1 parent f59d3ec commit d3cbab9

File tree

3 files changed

+77
-46
lines changed

3 files changed

+77
-46
lines changed

GNNlib/src/GNNlib.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,12 @@ export agnn_conv,
6161
transformer_conv
6262

6363
include("layers/temporalconv.jl")
64-
export tgcn_conv
64+
export a3tgcn_conv,
65+
dcgrucell_frwd,
66+
evolvegcnocell_frwd
67+
gconv_grucell_frwd,
68+
gconv_lstmcell_frwd,
69+
tgcn_frwd
6570

6671
include("layers/pool.jl")
6772
export global_pool,

GNNlib/src/layers/temporalconv.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,66 @@ function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray)
1010
return c
1111
end
1212

13+
14+
function gconvgrucell_frwd(cell, g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
15+
# reset gate
16+
r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h)
17+
r = NNlib.sigmoid_fast(r)
18+
# update gate
19+
z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h)
20+
z = NNlib.sigmoid_fast(z)
21+
# new gate
22+
= cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h)
23+
= NNlib.tanh_fast(h̃)
24+
h = (1 .- z) .*.+ z .* h
25+
return h, h
26+
end
27+
28+
29+
function gconvlstmcell_frwd(cell, g::GNNGraph, x::AbstractMatrix, (h, c))
30+
# input gate
31+
i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i
32+
i = NNlib.sigmoid_fast(i)
33+
# forget gate
34+
f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f
35+
f = NNlib.sigmoid_fast(f)
36+
# cell state
37+
c = f .* c .+ i .* NNlib.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c)
38+
# output gate
39+
o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o
40+
o = NNlib.sigmoid_fast(o)
41+
h = o .* NNlib.tanh_fast(c)
42+
return h, (h, c)
43+
end
44+
45+
function dcgrucell_frwd(cell, g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
46+
= vcat(x, h)
47+
z = cell.dconv_u(g, h̃)
48+
z = NNlib.sigmoid_fast.(z)
49+
r = cell.dconv_r(g, h̃)
50+
r = NNlib.sigmoid_fast.(r)
51+
= vcat(x, h .* r)
52+
c = cell.dconv_c(g, ĥ)
53+
c = NNlib.tanh_fast.(c)
54+
h = z.* h + (1 .- z) .* c
55+
return h, h
56+
end
57+
58+
59+
function evolvegcnocell_frwd(cell, g::GNNGraph, x::AbstractMatrix, state)
60+
weight, state_lstm = cell.lstm(state.weight, state.lstm)
61+
x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in)))
62+
return x, (; weight, lstm = state_lstm)
63+
end
64+
65+
66+
function tgcncell_frwd(cell, g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
67+
z = cell.conv_z(g, x)
68+
z = cell.dense_z(vcat(z, h))
69+
r = cell.conv_r(g, x)
70+
r = cell.dense_r(vcat(r, h))
71+
= cell.conv_h(g, x)
72+
= cell.dense_h(vcat(h̃, r .* h))
73+
h = (1 .- z) .* h .+ z .*
74+
return h, h
75+
end

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Temporal Convolutional Layers for Graph Neural Networks
2+
# Implementations are found in GNNlib
3+
14
function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
25
y = []
36
for xt in eachslice(x, dims = 2)
@@ -241,17 +244,7 @@ function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
241244
end
242245

243246
function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
244-
# reset gate
245-
r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h)
246-
r = Flux.sigmoid_fast(r)
247-
# update gate
248-
z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h)
249-
z = Flux.sigmoid_fast(z)
250-
# new gate
251-
= cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h)
252-
= Flux.tanh_fast(h̃)
253-
h = (1 .- z) .*.+ z .* h
254-
return h, h
247+
return gconvgrucell_frwd(cell, g, x, h)
255248
end
256249

257250
function Base.show(io::IO, cell::GConvGRUCell)
@@ -422,19 +415,7 @@ function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c))
422415
c = repeat(c, 1, g.num_nodes)
423416
end
424417
@assert ndims(h) == 2 && ndims(c) == 2
425-
# input gate
426-
i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i
427-
i = Flux.sigmoid_fast(i)
428-
# forget gate
429-
f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f
430-
f = Flux.sigmoid_fast(f)
431-
# cell state
432-
c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c)
433-
# output gate
434-
o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o
435-
o = Flux.sigmoid_fast(o)
436-
h = o .* Flux.tanh_fast(c)
437-
return h, (h, c)
418+
gconvlstmcell_frwd(cell, g, x, (h, c))
438419
end
439420

440421
function Base.show(io::IO, cell::GConvLSTMCell)
@@ -563,16 +544,7 @@ function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
563544
end
564545

565546
function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
566-
= vcat(x, h)
567-
z = cell.dconv_u(g, h̃)
568-
z = NNlib.sigmoid_fast.(z)
569-
r = cell.dconv_r(g, h̃)
570-
r = NNlib.sigmoid_fast.(r)
571-
= vcat(x, h .* r)
572-
c = cell.dconv_c(g, ĥ)
573-
c = NNlib.tanh_fast.(c)
574-
h = z.* h + (1 .- z) .* c
575-
return h, h
547+
return dcgrucell_frwd(cell, g, x, h)
576548
end
577549

578550
function Base.show(io::IO, cell::DCGRUCell)
@@ -700,9 +672,7 @@ end
700672
(cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))
701673

702674
function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state)
703-
weight, state_lstm = cell.lstm(state.weight, state.lstm)
704-
x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in)))
705-
return x, (; weight, lstm = state_lstm)
675+
return evolvegcno_frwd(cell, g, x, state.weight, state.lstm)
706676
end
707677

708678
function Base.show(io::IO, egcno::EvolveGCNOCell)
@@ -845,14 +815,7 @@ function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
845815
end
846816

847817
function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
848-
z = cell.conv_z(g, x)
849-
z = cell.dense_z(vcat(z, h))
850-
r = cell.conv_r(g, x)
851-
r = cell.dense_r(vcat(r, h))
852-
= cell.conv_h(g, x)
853-
= cell.dense_h(vcat(h̃, r .* h))
854-
h = (1 .- z) .* h .+ z .*
855-
return h, h
818+
return tgcncell_frwd(cell, g, x, h)
856819
end
857820

858821
function Base.show(io::IO, cell::TGCNCell)

0 commit comments

Comments
 (0)