|  | 
|  | 1 | +# Temporal Convolutional Layers for Graph Neural Networks | 
|  | 2 | +# Implementations are found in GNNlib | 
|  | 3 | + | 
| 1 | 4 | function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T} | 
| 2 | 5 |     y = [] | 
| 3 | 6 |     for xt in eachslice(x, dims = 2) | 
| @@ -241,17 +244,7 @@ function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) | 
| 241 | 244 | end | 
| 242 | 245 | 
 | 
| 243 | 246 | function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) | 
| 244 |  | -    # reset gate | 
| 245 |  | -    r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h) | 
| 246 |  | -    r = Flux.sigmoid_fast(r) | 
| 247 |  | -    # update gate | 
| 248 |  | -    z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h) | 
| 249 |  | -    z = Flux.sigmoid_fast(z) | 
| 250 |  | -    # new gate | 
| 251 |  | -    h̃ = cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h) | 
| 252 |  | -    h̃ = Flux.tanh_fast(h̃) | 
| 253 |  | -    h = (1 .- z) .* h̃ .+ z .* h  | 
| 254 |  | -    return h, h | 
|  | 247 | +    return gconvgrucell_frwd(cell, g, x, h) | 
| 255 | 248 | end | 
| 256 | 249 | 
 | 
| 257 | 250 | function Base.show(io::IO, cell::GConvGRUCell) | 
| @@ -422,19 +415,7 @@ function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c)) | 
| 422 | 415 |         c = repeat(c, 1, g.num_nodes) | 
| 423 | 416 |     end | 
| 424 | 417 |     @assert ndims(h) == 2 && ndims(c) == 2 | 
| 425 |  | -    # input gate | 
| 426 |  | -    i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i  | 
| 427 |  | -    i = Flux.sigmoid_fast(i) | 
| 428 |  | -    # forget gate | 
| 429 |  | -    f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f | 
| 430 |  | -    f = Flux.sigmoid_fast(f) | 
| 431 |  | -    # cell state | 
| 432 |  | -    c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c) | 
| 433 |  | -    # output gate | 
| 434 |  | -    o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o | 
| 435 |  | -    o = Flux.sigmoid_fast(o) | 
| 436 |  | -    h =  o .* Flux.tanh_fast(c) | 
| 437 |  | -    return h, (h, c) | 
|  | 418 | +    gconvlstmcell_frwd(cell, g, x, (h, c)) | 
| 438 | 419 | end | 
| 439 | 420 | 
 | 
| 440 | 421 | function Base.show(io::IO, cell::GConvLSTMCell) | 
| @@ -563,16 +544,7 @@ function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) | 
| 563 | 544 | end | 
| 564 | 545 | 
 | 
| 565 | 546 | function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) | 
| 566 |  | -    h̃ = vcat(x, h) | 
| 567 |  | -    z = cell.dconv_u(g, h̃) | 
| 568 |  | -    z = NNlib.sigmoid_fast.(z) | 
| 569 |  | -    r = cell.dconv_r(g, h̃) | 
| 570 |  | -    r = NNlib.sigmoid_fast.(r) | 
| 571 |  | -    ĥ = vcat(x, h .* r) | 
| 572 |  | -    c = cell.dconv_c(g, ĥ) | 
| 573 |  | -    c = NNlib.tanh_fast.(c) | 
| 574 |  | -    h = z.* h + (1 .- z) .* c | 
| 575 |  | -    return h, h | 
|  | 547 | +    return dcgrucell_frwd(cell, g, x, h) | 
| 576 | 548 | end | 
| 577 | 549 | 
 | 
| 578 | 550 | function Base.show(io::IO, cell::DCGRUCell) | 
|  | 
| 700 | 672 | (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) | 
| 701 | 673 | 
 | 
| 702 | 674 | function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state) | 
| 703 |  | -    weight, state_lstm = cell.lstm(state.weight, state.lstm) | 
| 704 |  | -    x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in))) | 
| 705 |  | -    return x, (; weight, lstm = state_lstm) | 
|  | 675 | +    return evolvegcno_frwd(cell, g, x, state.weight, state.lstm) | 
| 706 | 676 | end | 
| 707 | 677 | 
 | 
| 708 | 678 | function Base.show(io::IO, egcno::EvolveGCNOCell) | 
| @@ -845,14 +815,7 @@ function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) | 
| 845 | 815 | end | 
| 846 | 816 | 
 | 
| 847 | 817 | function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) | 
| 848 |  | -    z = cell.conv_z(g, x) | 
| 849 |  | -    z = cell.dense_z(vcat(z, h)) | 
| 850 |  | -    r = cell.conv_r(g, x) | 
| 851 |  | -    r = cell.dense_r(vcat(r, h)) | 
| 852 |  | -    h̃ = cell.conv_h(g, x) | 
| 853 |  | -    h̃ = cell.dense_h(vcat(h̃, r .* h)) | 
| 854 |  | -    h = (1 .- z) .* h .+ z .* h̃ | 
| 855 |  | -    return h, h | 
|  | 818 | +    return tgcncell_frwd(cell, g, x, h) | 
| 856 | 819 | end | 
| 857 | 820 | 
 | 
| 858 | 821 | function Base.show(io::IO, cell::TGCNCell) | 
|  | 
0 commit comments