Merged
15 changes: 10 additions & 5 deletions GNNLux/src/layers/temporalconv.jl
@@ -33,9 +33,11 @@ LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
init_state::Function
end

-function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
+function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform,
+    init_state = zeros32, init_bias = zeros32, add_self_loops = false,
+    use_edge_weight = true, act = sigmoid)
in_dims, out_dims = ch
-conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
+conv = GCNConv(ch, act; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
return TGCNCell(in_dims, out_dims, conv, gru, init_state)
end
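For intuition, the cell built here chains two pieces: the `GCNConv` captures spatial structure on the graph, and the `GRUCell` carries each node's hidden state through time. A minimal sketch of that composition, assuming the GNNLux `layer(g, x, ps, st)` calling convention shown above and the `conv`/`gru` field names from the constructor; this is an illustration only, not the package's actual forward pass:

```julia
function tgcn_cell_sketch(cell, g, x, ps, st)
    # spatial step: graph convolution with the configured activation `act`
    x̃, st_conv = cell.conv(g, x, ps.conv, st.conv)
    # temporal step: GRU update over the convolved node features
    (h, carry), st_gru = cell.gru(x̃, ps.gru, st.gru)
    return h, (conv = st_conv, gru = st_gru)
end
```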
@@ -57,7 +59,7 @@ function Base.show(io::IO, tgcn::TGCNCell)
end

"""
-TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
+TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true, act = sigmoid)

Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).

@@ -76,7 +78,7 @@ Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit
If `add_self_loops=true` the new weights will be set to 1.
This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
Default `false`.

+- `act`: Activation function used in the GCNConv layer. Default `sigmoid`.


# Examples
@@ -91,9 +93,12 @@ rng = Random.default_rng()
g = rand_graph(rng, 5, 10)
x = rand(rng, Float32, 2, 5)

# create TGCN layer
tgcn = TGCN(2 => 6)

+# create TGCN layer with custom activation
+tgcn_relu = TGCN(2 => 6, act = relu)
+
# setup layer
ps, st = LuxCore.setup(rng, tgcn)
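# run the forward pass (an editorial sketch, not part of the original
# docstring: the `tgcn(g, x, ps, st)` call follows the GNNLux convention
# `LuxCore.apply(m, g, x, ps, st)` shown above, and the output having one
# 6-dimensional hidden vector per node is an assumption)
y, st = tgcn(g, x, ps, st)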

14 changes: 14 additions & 0 deletions GNNLux/test/layers/temporalconv.jl
@@ -10,11 +10,25 @@
tx = [x for _ in 1:5]

@testset "TGCN" begin
+# Test with default activation (sigmoid)
l = TGCN(3=>3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
+y1, _ = l(g, x, ps, st)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])

+# Test with custom activation (relu)
+l_relu = TGCN(3=>3, act = relu)
+ps_relu = LuxCore.initialparameters(rng, l_relu)
+st_relu = LuxCore.initialstates(rng, l_relu)
+y2, _ = l_relu(g, x, ps_relu, st_relu)
+
+# Outputs should be different with different activation functions
+@test !isapprox(y1, y2, rtol=1.0f-2)
+
+loss_relu = (x, ps) -> sum(first(l_relu(g, x, ps, st_relu)))
+test_gradients(loss_relu, x, ps_relu; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

@testset "A3TGCN" begin
32 changes: 23 additions & 9 deletions GraphNeuralNetworks/src/layers/temporalconv.jl
@@ -758,7 +758,7 @@ EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))


"""
-TGCNCell(in => out; kws...)
+TGCNCell(in => out; act = relu, kws...)

Recurrent graph convolutional cell from the paper
[T-GCN: A Temporal Graph Convolutional
@@ -824,12 +824,14 @@ end

Flux.@layer :noexpand TGCNCell

-function TGCNCell((in, out)::Pair{Int, Int}; kws...)
-conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+function TGCNCell((in, out)::Pair{Int, Int};
+    act = relu,
+    kws...)
+conv_z = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
dense_z = Dense(2*out => out, sigmoid)
-conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+conv_r = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
dense_r = Dense(2*out => out, sigmoid)
-conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+conv_h = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
dense_h = Dense(2*out => out, tanh)
return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
end
@@ -858,16 +860,22 @@ function Base.show(io::IO, cell::TGCNCell)
end

"""
-TGCN(args...; kws...)
+TGCN(args...; act = relu, kws...)

Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell.

The arguments are passed to the [`TGCNCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.

+# Additional Parameters
+
+- `act`: Activation function for the GCNConv layers. Default `relu`.
+
# Examples

```jldoctest
+julia> using Flux # Ensure activation functions are available
+
julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;
@@ -876,9 +884,14 @@ julia> timesteps = 5;

julia> g = rand_graph(num_nodes, num_edges);

-julia> x = rand(Float32, d_in, timesteps, num_nodes);
+julia> x = rand(Float32, d_in, timesteps, g.num_nodes);

+julia> layer = TGCN(d_in => d_out) # Default activation (relu)
+GNNRecurrence(
+  TGCNCell(2 => 3), # 126 parameters
+) # Total: 18 arrays, 126 parameters, 1.469 KiB.
+
-julia> layer = TGCN(d_in => d_out)
+julia> layer_tanh = TGCN(d_in => d_out, act = tanh) # Custom activation
GNNRecurrence(
TGCNCell(2 => 3), # 126 parameters
) # Total: 18 arrays, 126 parameters, 1.469 KiB.
@@ -889,5 +902,6 @@ julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
-TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...))
+TGCN(args...; act = relu, kws...) =
+    GNNRecurrence(TGCNCell(args...; act, kws...))
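The recurrence can also be seeded with an explicit initial hidden state, as the test suite below exercises. A minimal sketch, assuming the positional `layer(g, x, state0)` form implied by those tests and a `(d_out, num_nodes)` state shape:

```julia
state0 = rand(Float32, d_out, num_nodes)  # one hidden-state column per node
y = layer(g, x, state0)                   # output shape unchanged: (d_out, timesteps, num_nodes)
```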

30 changes: 30 additions & 0 deletions GraphNeuralNetworks/test/layers/temporalconv.jl
@@ -25,6 +25,8 @@ end

@testitem "TGCNCell" setup=[TemporalConvTestModule, TestModule] begin
using .TemporalConvTestModule, .TestModule

+# Test with default activation function
cell = GraphNeuralNetworks.TGCNCell(in_channel => out_channel)
y, h = cell(g, g.x)
@test y === h
@@ -33,10 +35,25 @@
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
# with initial state
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH)

+# Test with custom activation function
+custom_activation = tanh
+cell_custom = GraphNeuralNetworks.TGCNCell(in_channel => out_channel, act = custom_activation)
+y_custom, h_custom = cell_custom(g, g.x)
+@test y_custom === h_custom
+@test size(h_custom) == (out_channel, g.num_nodes)
+# Test that outputs differ when using different activation functions
+@test !isapprox(y, y_custom, rtol=RTOL_HIGH)
+# with no initial state
+test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
+# with initial state
+test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH)
end

@testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin
using .TemporalConvTestModule, .TestModule

+# Test with default activation function
layer = TGCN(in_channel => out_channel)
x = rand(Float32, in_channel, timesteps, g.num_nodes)
state0 = rand(Float32, out_channel, g.num_nodes)
@@ -48,6 +65,19 @@
# with initial state
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH)

+# Test with custom activation function
+custom_activation = tanh
+layer_custom = TGCN(in_channel => out_channel, act = custom_activation)
+y_custom = layer_custom(g, x)
+@test layer_custom isa GNNRecurrence
+@test size(y_custom) == (out_channel, timesteps, g.num_nodes)
+# Test that outputs differ when using different activation functions
+@test !isapprox(y, y_custom, rtol = RTOL_HIGH)
+# with no initial state
+test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
+# with initial state
+test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH)
+
# interplay with GNNChain
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
y = model(g, x)