|
| 1 | +# Temporal Convolutional Layers for Graph Neural Networks |
| 2 | +# Implementations are found in GNNlib |
| 3 | + |
1 | 4 | function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T} |
2 | 5 | y = [] |
3 | 6 | for xt in eachslice(x, dims = 2) |
@@ -241,17 +244,7 @@ function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) |
241 | 244 | end |
242 | 245 |
|
243 | 246 | function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) |
244 | | - # reset gate |
245 | | - r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h) |
246 | | - r = Flux.sigmoid_fast(r) |
247 | | - # update gate |
248 | | - z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h) |
249 | | - z = Flux.sigmoid_fast(z) |
250 | | - # new gate |
251 | | - h̃ = cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h) |
252 | | - h̃ = Flux.tanh_fast(h̃) |
253 | | - h = (1 .- z) .* h̃ .+ z .* h |
254 | | - return h, h |
| 247 | + return gconvgrucell_frwd(cell, g, x, h) |
255 | 248 | end |
256 | 249 |
|
257 | 250 | function Base.show(io::IO, cell::GConvGRUCell) |
@@ -422,19 +415,7 @@ function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c)) |
422 | 415 | c = repeat(c, 1, g.num_nodes) |
423 | 416 | end |
424 | 417 | @assert ndims(h) == 2 && ndims(c) == 2 |
425 | | - # input gate |
426 | | - i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i |
427 | | - i = Flux.sigmoid_fast(i) |
428 | | - # forget gate |
429 | | - f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f |
430 | | - f = Flux.sigmoid_fast(f) |
431 | | - # cell state |
432 | | - c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c) |
433 | | - # output gate |
434 | | - o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o |
435 | | - o = Flux.sigmoid_fast(o) |
436 | | - h = o .* Flux.tanh_fast(c) |
437 | | - return h, (h, c) |
| 418 | + gconvlstmcell_frwd(cell, g, x, (h, c)) |
438 | 419 | end |
439 | 420 |
|
440 | 421 | function Base.show(io::IO, cell::GConvLSTMCell) |
@@ -563,16 +544,7 @@ function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) |
563 | 544 | end |
564 | 545 |
|
565 | 546 | function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) |
566 | | - h̃ = vcat(x, h) |
567 | | - z = cell.dconv_u(g, h̃) |
568 | | - z = NNlib.sigmoid_fast.(z) |
569 | | - r = cell.dconv_r(g, h̃) |
570 | | - r = NNlib.sigmoid_fast.(r) |
571 | | - ĥ = vcat(x, h .* r) |
572 | | - c = cell.dconv_c(g, ĥ) |
573 | | - c = NNlib.tanh_fast.(c) |
574 | | - h = z.* h + (1 .- z) .* c |
575 | | - return h, h |
| 547 | + return dcgrucell_frwd(cell, g, x, h) |
576 | 548 | end |
577 | 549 |
|
578 | 550 | function Base.show(io::IO, cell::DCGRUCell) |
|
700 | 672 | (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) |
701 | 673 |
|
702 | 674 | function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state) |
703 | | - weight, state_lstm = cell.lstm(state.weight, state.lstm) |
704 | | - x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in))) |
705 | | - return x, (; weight, lstm = state_lstm) |
| 675 | + return evolvegcno_frwd(cell, g, x, state.weight, state.lstm) |
706 | 676 | end |
707 | 677 |
|
708 | 678 | function Base.show(io::IO, egcno::EvolveGCNOCell) |
@@ -845,14 +815,7 @@ function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) |
845 | 815 | end |
846 | 816 |
|
847 | 817 | function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) |
848 | | - z = cell.conv_z(g, x) |
849 | | - z = cell.dense_z(vcat(z, h)) |
850 | | - r = cell.conv_r(g, x) |
851 | | - r = cell.dense_r(vcat(r, h)) |
852 | | - h̃ = cell.conv_h(g, x) |
853 | | - h̃ = cell.dense_h(vcat(h̃, r .* h)) |
854 | | - h = (1 .- z) .* h .+ z .* h̃ |
855 | | - return h, h |
| 818 | + return tgcncell_frwd(cell, g, x, h) |
856 | 819 | end |
857 | 820 |
|
858 | 821 | function Base.show(io::IO, cell::TGCNCell) |
|
0 commit comments