diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl
index e58901774..d6036cd8f 100644
--- a/GNNLux/src/layers/temporalconv.jl
+++ b/GNNLux/src/layers/temporalconv.jl
@@ -33,9 +33,11 @@ LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
     init_state::Function
 end
 
-function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
+function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform,
+                  init_state = zeros32, init_bias = zeros32, add_self_loops = false,
+                  use_edge_weight = true, act = sigmoid)
     in_dims, out_dims = ch
-    conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
+    conv = GCNConv(ch, act; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
     gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
     return TGCNCell(in_dims, out_dims, conv, gru, init_state)
 end
@@ -57,7 +59,7 @@ function Base.show(io::IO, tgcn::TGCNCell)
 end
 
 """
-    TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
+    TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true, act = sigmoid)
 
 Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).
 
@@ -76,7 +78,7 @@ Performs a layer of GCNConv to model spatial dependencies, followed by a Gated R
   If `add_self_loops=true` the new weights will be set to 1.
   This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
   Default `false`.
-
+- `act`: Activation function used in the GCNConv layer. Default `sigmoid`.
 
 # Examples
 
@@ -91,9 +93,12 @@ rng = Random.default_rng()
 g = rand_graph(rng, 5, 10)
 x = rand(rng, Float32, 2, 5)
 
 # create TGCN layer
 tgcn = TGCN(2 => 6)
 
+# create TGCN layer with custom activation
+tgcn_relu = TGCN(2 => 6, act = relu)
+
 # setup layer
 ps, st = LuxCore.setup(rng, tgcn)
 
diff --git a/GNNLux/test/layers/temporalconv.jl b/GNNLux/test/layers/temporalconv.jl
index 8de93efa7..794775f88 100644
--- a/GNNLux/test/layers/temporalconv.jl
+++ b/GNNLux/test/layers/temporalconv.jl
@@ -10,11 +10,25 @@
     tx = [x for _ in 1:5]
 
     @testset "TGCN" begin
+        # Test with default activation (sigmoid)
        l = TGCN(3=>3)
         ps = LuxCore.initialparameters(rng, l)
         st = LuxCore.initialstates(rng, l)
+        y1, _ = l(g, x, ps, st)
         loss = (x, ps) -> sum(first(l(g, x, ps, st)))
         test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
+
+        # Test with custom activation (relu)
+        l_relu = TGCN(3=>3, act = relu)
+        ps_relu = LuxCore.initialparameters(rng, l_relu)
+        st_relu = LuxCore.initialstates(rng, l_relu)
+        y2, _ = l_relu(g, x, ps_relu, st_relu)
+
+        # Outputs should be different with different activation functions
+        @test !isapprox(y1, y2, rtol=1.0f-2)
+
+        loss_relu = (x, ps) -> sum(first(l_relu(g, x, ps, st_relu)))
+        test_gradients(loss_relu, x, ps_relu; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
     end
 
     @testset "A3TGCN" begin
diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl
index 2872529d6..5516d8015 100644
--- a/GraphNeuralNetworks/src/layers/temporalconv.jl
+++ b/GraphNeuralNetworks/src/layers/temporalconv.jl
@@ -758,7 +758,7 @@ EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))
 
 """
-    TGCNCell(in => out; kws...)
+    TGCNCell(in => out; act = relu, kws...)
 
 Recurrent graph convolutional cell from the paper
 [T-GCN: A Temporal Graph Convolutional
@@ -824,12 +824,14 @@ end
 
 Flux.@layer :noexpand TGCNCell
 
-function TGCNCell((in, out)::Pair{Int, Int}; kws...)
-    conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+function TGCNCell((in, out)::Pair{Int, Int};
+                  act = relu,
+                  kws...)
+    conv_z = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
     dense_z = Dense(2*out => out, sigmoid)
-    conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+    conv_r = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
     dense_r = Dense(2*out => out, sigmoid)
-    conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+    conv_h = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
     dense_h = Dense(2*out => out, tanh)
     return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
 end
@@ -868,6 +870,8 @@ See [`GNNRecurrence`](@ref) for more details.
 # Examples
 
 ```jldoctest
+julia> using Flux # Ensure activation functions are available
+
 julia> num_nodes, num_edges = 5, 10;
 
 julia> d_in, d_out = 2, 3;
@@ -876,9 +880,14 @@ julia> timesteps = 5;
 
 julia> g = rand_graph(num_nodes, num_edges);
 
-julia> x = rand(Float32, d_in, timesteps, num_nodes);
+julia> x = rand(Float32, d_in, timesteps, g.num_nodes);
+
+julia> layer = TGCN(d_in => d_out) # Default activation (relu)
+GNNRecurrence(
+  TGCNCell(2 => 3), # 126 parameters
+) # Total: 18 arrays, 126 parameters, 1.469 KiB.
 
-julia> layer = TGCN(d_in => d_out)
+julia> layer_tanh = TGCN(d_in => d_out, act = tanh) # Custom activation
 GNNRecurrence(
   TGCNCell(2 => 3), # 126 parameters
 ) # Total: 18 arrays, 126 parameters, 1.469 KiB.
@@ -889,5 +898,6 @@
 julia> size(y) # (d_out, timesteps, num_nodes)
 (3, 5, 5)
 ```
 """
-TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...))
+TGCN(args...; kws...) =
+    GNNRecurrence(TGCNCell(args...; kws...))
diff --git a/GraphNeuralNetworks/test/layers/temporalconv.jl b/GraphNeuralNetworks/test/layers/temporalconv.jl
index f8d96c0e2..93d6b0082 100644
--- a/GraphNeuralNetworks/test/layers/temporalconv.jl
+++ b/GraphNeuralNetworks/test/layers/temporalconv.jl
@@ -25,6 +25,8 @@ end
 
 @testitem "TGCNCell" setup=[TemporalConvTestModule, TestModule] begin
     using .TemporalConvTestModule, .TestModule
+
+    # Test with default activation function
     cell = GraphNeuralNetworks.TGCNCell(in_channel => out_channel)
     y, h = cell(g, g.x)
     @test y === h
@@ -33,10 +35,25 @@
     test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
     # with initial state
     test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH)
+
+    # Test with custom activation function
+    custom_activation = tanh
+    cell_custom = GraphNeuralNetworks.TGCNCell(in_channel => out_channel, act = custom_activation)
+    y_custom, h_custom = cell_custom(g, g.x)
+    @test y_custom === h_custom
+    @test size(h_custom) == (out_channel, g.num_nodes)
+    # Test that outputs differ when using different activation functions
+    @test !isapprox(y, y_custom, rtol=RTOL_HIGH)
+    # with no initial state
+    test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
+    # with initial state
+    test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH)
 end
 
 @testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin
     using .TemporalConvTestModule, .TestModule
+
+    # Test with default activation function
     layer = TGCN(in_channel => out_channel)
     x = rand(Float32, in_channel, timesteps, g.num_nodes)
     state0 = rand(Float32, out_channel, g.num_nodes)
@@ -48,6 +65,19 @@
     # with initial state
     test_gradients(layer, g, x, state0, rtol = RTOL_HIGH)
 
+    # Test with custom activation function
+    custom_activation = tanh
+    layer_custom = TGCN(in_channel => out_channel, act = custom_activation)
+    y_custom = layer_custom(g, x)
+    @test layer_custom isa GNNRecurrence
+    @test size(y_custom) == (out_channel, timesteps, g.num_nodes)
+    # Test that outputs differ when using different activation functions
+    @test !isapprox(y, y_custom, rtol = RTOL_HIGH)
+    # with no initial state
+    test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
+    # with initial state
+    test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH)
+
     # interplay with GNNChain
     model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
     y = model(g, x)
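For context, a minimal usage sketch of the new `act` keyword with the GraphNeuralNetworks API, mirroring the updated docstring examples; the graph size and feature dimensions below are illustrative, not taken from the patch:

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(5, 10)                  # 5 nodes, 10 edges
x = rand(Float32, 2, 4, g.num_nodes)   # (features, timesteps, nodes)

layer = TGCN(2 => 3)                   # default: relu in the GCNConv gate chains
layer_tanh = TGCN(2 => 3, act = tanh)  # custom activation via the new keyword

y = layer(g, x)                        # output of size (d_out, timesteps, num_nodes)
y_tanh = layer_tanh(g, x)
@assert size(y) == size(y_tanh) == (3, 4, 5)
```

Note that the two packages keep their existing defaults: `GNNLux.TGCN` defaults to `act = sigmoid` (the previously hard-coded activation of its single `GCNConv`), while `GraphNeuralNetworks.TGCN` defaults to `act = relu` (the previously hard-coded activation of the first `GCNConv` in each gate's chain).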