Skip to content

Commit fc2941b

Browse files
Add `gate_activation=relu` to the GNNChain test for TGCN and fix the documentation
1 parent: be10c34 · commit: fc2941b

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ function Base.show(io::IO, cell::TGCNCell)
860860
end
861861

862862
"""
863-
TGCN(args...; gate_activation = sigmoid, hidden_activation = tanh, kws...)
863+
TGCN(args...; gate_activation = sigmoid, kws...)
864864
865865
Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell.
866866
@@ -870,7 +870,6 @@ See [`GNNRecurrence`](@ref) for more details.
870870
# Additional Parameters
871871
872872
- `gate_activation`: Activation function for the gate mechanisms. Default `sigmoid`.
873-
- `hidden_activation`: Activation function for the hidden state update. Default `tanh`.
874873
875874
# Examples
876875
@@ -885,7 +884,7 @@ julia> g = rand_graph(num_nodes, num_edges);
885884
886885
julia> x = rand(Float32, d_in, timesteps, num_nodes);
887886
888-
julia> layer = TGCN(d_in => d_out, hidden_activation = relu)
887+
julia> layer = TGCN(d_in => d_out, gate_activation = relu)
889888
GNNRecurrence(
890889
TGCNCell(2 => 3), # 126 parameters
891890
) # Total: 18 arrays, 126 parameters, 1.469 KiB.

GraphNeuralNetworks/test/layers/temporalconv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ end
7171
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
7272

7373
# interplay with GNNChain
74-
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
74+
model = GNNChain(TGCN(in_channel => out_channel, gate_activation=relu), Dense(out_channel, 1))
7575
y = model(g, x)
7676
@test size(y) == (1, timesteps, g.num_nodes)
7777
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)

0 commit comments

Comments (0)