Merged
17 changes: 14 additions & 3 deletions GNNLux/src/layers/temporalconv.jl
@@ -33,10 +33,21 @@ LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
init_state::Function
end

-function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
+function TGCNCell(ch::Pair{Int, Int};
+                  use_bias = true,
+                  init_weight = glorot_uniform,
+                  init_state = zeros32,
+                  init_bias = zeros32,
+                  add_self_loops = false,
+                  use_edge_weight = true,
+                  gate_activation = sigmoid)
in_dims, out_dims = ch
-conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
-gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
+conv = GCNConv(ch, gate_activation; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
+gru = Lux.GRUCell(out_dims => out_dims;
+                  use_bias,
+                  init_weight = (init_weight, init_weight, init_weight),
+                  init_bias = (init_bias, init_bias, init_bias),
+                  init_state = init_state)
return TGCNCell(in_dims, out_dims, conv, gru, init_state)
end

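Not part of the diff: a minimal sketch of exercising the new keyword on the Lux side, mirroring the test added below. It assumes GNNLux re-exports `rand_graph` from GNNGraphs, that `relu` comes from Lux, and that graph/feature sizes are arbitrary illustration values.

```julia
using GNNLux, Lux, LuxCore, Random

rng = Random.default_rng()
g = rand_graph(5, 10)                       # hypothetical graph: 5 nodes, 10 edges
x = rand(rng, Float32, 3, 5)                # 3 features per node

# relu gates instead of the default sigmoid
l = TGCN(3 => 3, gate_activation = relu)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
y, st_new = l(g, x, ps, st)                 # Lux convention: (output, updated state)
```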
8 changes: 8 additions & 0 deletions GNNLux/test/layers/temporalconv.jl
@@ -56,4 +56,12 @@
loss = (tx, ps) -> sum(sum(first(l(tg, tx, ps, st))))
test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end
+
+@testset "TGCN with Custom Activations" begin
+    l = TGCN(3 => 3, gate_activation = Lux.relu)
+    ps = LuxCore.initialparameters(rng, l)
+    st = LuxCore.initialstates(rng, l)
+    loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+    test_gradients(loss, x, ps; atol = 1.0f-2, rtol = 1.0f-2, skip_backends = [AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
+end
end
21 changes: 15 additions & 6 deletions GraphNeuralNetworks/src/layers/temporalconv.jl
@@ -824,11 +824,13 @@ end

Flux.@layer :noexpand TGCNCell

-function TGCNCell((in, out)::Pair{Int, Int}; kws...)
+function TGCNCell((in, out)::Pair{Int, Int};
+                  gate_activation = sigmoid,
+                  kws...)
conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
-dense_z = Dense(2*out => out, sigmoid)
+dense_z = Dense(2*out => out, gate_activation)
conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
-dense_r = Dense(2*out => out, sigmoid)
+dense_r = Dense(2*out => out, gate_activation)
conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
dense_h = Dense(2*out => out, tanh)
return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
@@ -858,16 +860,22 @@ function Base.show(io::IO, cell::TGCNCell)
end

"""
-    TGCN(args...; kws...)
+    TGCN(args...; gate_activation = sigmoid, kws...)

Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell.

The arguments are passed to the [`TGCNCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.

+# Additional Parameters
+
+- `gate_activation`: activation function for the update and reset gates. Defaults to `sigmoid`; the candidate-state activation remains `tanh`.
+
# Examples

```jldoctest
+julia> using Flux # ensure relu is in scope
+
julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;
@@ -878,7 +886,7 @@ julia> g = rand_graph(num_nodes, num_edges);

julia> x = rand(Float32, d_in, timesteps, num_nodes);

-julia> layer = TGCN(d_in => d_out)
+julia> layer = TGCN(d_in => d_out, gate_activation = relu)
GNNRecurrence(
TGCNCell(2 => 3), # 126 parameters
) # Total: 18 arrays, 126 parameters, 1.469 KiB.
@@ -889,5 +897,6 @@ julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
-TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...))
+TGCN(args...; gate_activation = sigmoid, kws...) =
+    GNNRecurrence(TGCNCell(args...; gate_activation, kws...))
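Not part of the diff: a sketch of a single explicit-mode Flux training step using the new keyword. Note that `gate_activation` replaces `sigmoid` only in the update- and reset-gate denses (`dense_z`, `dense_r` above); the candidate state keeps `tanh`. Graph and feature sizes are illustrative.

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(5, 10)                        # hypothetical graph: 5 nodes, 10 edges
x = rand(Float32, 2, 7, 5)                   # (d_in, timesteps, num_nodes)

layer = TGCN(2 => 3, gate_activation = relu) # relu gates instead of sigmoid
opt = Flux.setup(Adam(1e-3), layer)

# one gradient step against a dummy sum-of-squares objective
grads = Flux.gradient(m -> sum(abs2, m(g, x)), layer)
Flux.update!(opt, layer, grads[1])
```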

22 changes: 22 additions & 0 deletions GraphNeuralNetworks/test/layers/temporalconv.jl
@@ -55,6 +55,28 @@ end
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
end

@testitem "TGCN with custom activations" setup=[TemporalConvTestModule, TestModule] begin
using .TemporalConvTestModule, .TestModule
using Flux: relu, sigmoid
layer_default = TGCN(in_channel => out_channel)
layer_custom = TGCN(in_channel => out_channel, gate_activation = relu)
x = rand(Float32, in_channel, timesteps, g.num_nodes)
y_default = layer_default(g, x)
y_custom = layer_custom(g, x)
@test layer_default isa GNNRecurrence
@test layer_custom isa GNNRecurrence
@test size(y_default) == size(y_custom)
# Outputs should be different due to different activations
@test y_default != y_custom
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)

# interplay with GNNChain
model = GNNChain(TGCN(in_channel => out_channel, gate_activation=relu), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
end

@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin
using .TemporalConvTestModule, .TestModule
cell = GConvLSTMCell(in_channel => out_channel, 2)