Skip to content

Commit ee24b65

Browse files
Adding TGCN with parameters for custom activations functions in hidden and gate activations
1 parent 58fcd7d commit ee24b65

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -824,13 +824,16 @@ end
824824

825825
Flux.@layer :noexpand TGCNCell
826826

827-
function TGCNCell((in, out)::Pair{Int, Int}; kws...)
827+
function TGCNCell((in, out)::Pair{Int, Int};
828+
gate_activation = sigmoid,
829+
hidden_activation = tanh,
830+
kws...)
828831
conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
829-
dense_z = Dense(2*out => out, sigmoid)
832+
dense_z = Dense(2*out => out, gate_activation)
830833
conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
831-
dense_r = Dense(2*out => out, sigmoid)
834+
dense_r = Dense(2*out => out, gate_activation)
832835
conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
833-
dense_h = Dense(2*out => out, tanh)
836+
dense_h = Dense(2*out => out, hidden_activation)
834837
return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
835838
end
836839

@@ -858,13 +861,18 @@ function Base.show(io::IO, cell::TGCNCell)
858861
end
859862

860863
"""
861-
TGCN(args...; kws...)
864+
TGCN(args...; gate_activation = sigmoid, hidden_activation = tanh, kws...)
862865
863866
Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell.
864867
865868
The arguments are passed to the [`TGCNCell`](@ref) constructor.
866869
See [`GNNRecurrence`](@ref) for more details.
867870
871+
# Additional Parameters
872+
873+
- `gate_activation`: Activation function for the gate mechanisms. Default `sigmoid`.
874+
- `hidden_activation`: Activation function for the hidden state update. Default `tanh`.
875+
868876
# Examples
869877
870878
```jldoctest
@@ -878,7 +886,7 @@ julia> g = rand_graph(num_nodes, num_edges);
878886
879887
julia> x = rand(Float32, d_in, timesteps, num_nodes);
880888
881-
julia> layer = TGCN(d_in => d_out)
889+
julia> layer = TGCN(d_in => d_out, hidden_activation = relu)
882890
GNNRecurrence(
883891
TGCNCell(2 => 3), # 126 parameters
884892
) # Total: 18 arrays, 126 parameters, 1.469 KiB.
@@ -889,5 +897,6 @@ julia> size(y) # (d_out, timesteps, num_nodes)
889897
(3, 5, 5)
890898
```
891899
"""
892-
TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...))
900+
TGCN(args...; gate_activation = sigmoid, hidden_activation = tanh, kws...) =
901+
GNNRecurrence(TGCNCell(args...; gate_activation, hidden_activation, kws...))
893902

GraphNeuralNetworks/test/layers/temporalconv.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,28 @@ end
5555
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
5656
end
5757

58+
@testitem "TGCN with custom activations" setup=[TemporalConvTestModule, TestModule] begin
59+
using .TemporalConvTestModule, .TestModule
60+
using Flux: relu, sigmoid
61+
layer_default = TGCN(in_channel => out_channel)
62+
layer_custom = TGCN(in_channel => out_channel, gate_activation = relu, hidden_activation = relu)
63+
x = rand(Float32, in_channel, timesteps, g.num_nodes)
64+
y_default = layer_default(g, x)
65+
y_custom = layer_custom(g, x)
66+
@test layer_default isa GNNRecurrence
67+
@test layer_custom isa GNNRecurrence
68+
@test size(y_default) == size(y_custom)
69+
# Outputs should be different due to different activations
70+
@test y_default != y_custom
71+
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
72+
73+
# interplay with GNNChain
74+
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
75+
y = model(g, x)
76+
@test size(y) == (1, timesteps, g.num_nodes)
77+
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
78+
end
79+
5880
@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin
5981
using .TemporalConvTestModule, .TestModule
6082
cell = GConvLSTMCell(in_channel => out_channel, 2)

0 commit comments

Comments
 (0)