|
| 1 | +# Temporal Convolutional Layers for Graph Neural Networks |
| 2 | +# Implementations are found in GNNlib |
| 3 | + |
1 | 4 | function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
|
2 | 5 | y = []
|
3 | 6 | for xt in eachslice(x, dims = 2)
|
@@ -241,17 +244,7 @@ function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
|
241 | 244 | end
|
242 | 245 |
|
243 | 246 | function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
|
244 |
| - # reset gate |
245 |
| - r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h) |
246 |
| - r = Flux.sigmoid_fast(r) |
247 |
| - # update gate |
248 |
| - z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h) |
249 |
| - z = Flux.sigmoid_fast(z) |
250 |
| - # new gate |
251 |
| - h̃ = cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h) |
252 |
| - h̃ = Flux.tanh_fast(h̃) |
253 |
| - h = (1 .- z) .* h̃ .+ z .* h |
254 |
| - return h, h |
| 247 | + return gconvgrucell_frwd(cell, g, x, h) |
255 | 248 | end
|
256 | 249 |
|
257 | 250 | function Base.show(io::IO, cell::GConvGRUCell)
|
@@ -422,19 +415,7 @@ function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c))
|
422 | 415 | c = repeat(c, 1, g.num_nodes)
|
423 | 416 | end
|
424 | 417 | @assert ndims(h) == 2 && ndims(c) == 2
|
425 |
| - # input gate |
426 |
| - i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i |
427 |
| - i = Flux.sigmoid_fast(i) |
428 |
| - # forget gate |
429 |
| - f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f |
430 |
| - f = Flux.sigmoid_fast(f) |
431 |
| - # cell state |
432 |
| - c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c) |
433 |
| - # output gate |
434 |
| - o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o |
435 |
| - o = Flux.sigmoid_fast(o) |
436 |
| - h = o .* Flux.tanh_fast(c) |
437 |
| - return h, (h, c) |
| 418 | + gconvlstmcell_frwd(cell, g, x, (h, c)) |
438 | 419 | end
|
439 | 420 |
|
440 | 421 | function Base.show(io::IO, cell::GConvLSTMCell)
|
@@ -563,16 +544,7 @@ function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
|
563 | 544 | end
|
564 | 545 |
|
565 | 546 | function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
|
566 |
| - h̃ = vcat(x, h) |
567 |
| - z = cell.dconv_u(g, h̃) |
568 |
| - z = NNlib.sigmoid_fast.(z) |
569 |
| - r = cell.dconv_r(g, h̃) |
570 |
| - r = NNlib.sigmoid_fast.(r) |
571 |
| - ĥ = vcat(x, h .* r) |
572 |
| - c = cell.dconv_c(g, ĥ) |
573 |
| - c = NNlib.tanh_fast.(c) |
574 |
| - h = z.* h + (1 .- z) .* c |
575 |
| - return h, h |
| 547 | + return dcgrucell_frwd(cell, g, x, h) |
576 | 548 | end
|
577 | 549 |
|
578 | 550 | function Base.show(io::IO, cell::DCGRUCell)
|
|
700 | 672 | (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))
|
701 | 673 |
|
702 | 674 | function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state)
|
703 |
| - weight, state_lstm = cell.lstm(state.weight, state.lstm) |
704 |
| - x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in))) |
705 |
| - return x, (; weight, lstm = state_lstm) |
| 675 | + return evolvegcno_frwd(cell, g, x, state.weight, state.lstm) |
706 | 676 | end
|
707 | 677 |
|
708 | 678 | function Base.show(io::IO, egcno::EvolveGCNOCell)
|
@@ -845,14 +815,7 @@ function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
|
845 | 815 | end
|
846 | 816 |
|
847 | 817 | function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
|
848 |
| - z = cell.conv_z(g, x) |
849 |
| - z = cell.dense_z(vcat(z, h)) |
850 |
| - r = cell.conv_r(g, x) |
851 |
| - r = cell.dense_r(vcat(r, h)) |
852 |
| - h̃ = cell.conv_h(g, x) |
853 |
| - h̃ = cell.dense_h(vcat(h̃, r .* h)) |
854 |
| - h = (1 .- z) .* h .+ z .* h̃ |
855 |
| - return h, h |
| 818 | + return tgcncell_frwd(cell, g, x, h) |
856 | 819 | end
|
857 | 820 |
|
858 | 821 | function Base.show(io::IO, cell::TGCNCell)
|
|
0 commit comments