@@ -824,13 +824,16 @@ end
824824
825825Flux. @layer :noexpand TGCNCell
826826
827- function TGCNCell ((in, out):: Pair{Int, Int} ; kws... )
827+ function TGCNCell ((in, out):: Pair{Int, Int} ;
828+ gate_activation = sigmoid,
829+ hidden_activation = tanh,
830+ kws... )
828831 conv_z = GNNChain (GCNConv (in => out, relu; kws... ), GCNConv (out => out; kws... ))
829- dense_z = Dense (2 * out => out, sigmoid )
832+ dense_z = Dense (2 * out => out, gate_activation )
830833 conv_r = GNNChain (GCNConv (in => out, relu; kws... ), GCNConv (out => out; kws... ))
831- dense_r = Dense (2 * out => out, sigmoid )
834+ dense_r = Dense (2 * out => out, gate_activation )
832835 conv_h = GNNChain (GCNConv (in => out, relu; kws... ), GCNConv (out => out; kws... ))
833- dense_h = Dense (2 * out => out, tanh )
836+ dense_h = Dense (2 * out => out, hidden_activation )
834837 return TGCNCell (in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
835838end
836839
@@ -858,13 +861,18 @@ function Base.show(io::IO, cell::TGCNCell)
858861end
859862
860863"""
861- TGCN(args...; kws...)
864+ TGCN(args...; gate_activation = sigmoid, hidden_activation = tanh, kws...)
862865
863866Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell.
864867
865868The arguments are passed to the [`TGCNCell`](@ref) constructor.
866869See [`GNNRecurrence`](@ref) for more details.
867870
871+ # Additional Parameters
872+
873+ - `gate_activation`: Activation function for the gate mechanisms. Default `sigmoid`.
874+ - `hidden_activation`: Activation function for the hidden state update. Default `tanh`.
875+
868876# Examples
869877
870878```jldoctest
@@ -878,7 +886,7 @@ julia> g = rand_graph(num_nodes, num_edges);
878886
879887julia> x = rand(Float32, d_in, timesteps, num_nodes);
880888
881- julia> layer = TGCN(d_in => d_out)
889+ julia> layer = TGCN(d_in => d_out, hidden_activation = relu )
882890GNNRecurrence(
883891 TGCNCell(2 => 3), # 126 parameters
884892) # Total: 18 arrays, 126 parameters, 1.469 KiB.
@@ -889,5 +897,6 @@ julia> size(y) # (d_out, timesteps, num_nodes)
889897(3, 5, 5)
890898```
891899"""
892- TGCN (args... ; kws... ) = GNNRecurrence (TGCNCell (args... ; kws... ))
900+ TGCN (args... ; gate_activation = sigmoid, hidden_activation = tanh, kws... ) =
901+ GNNRecurrence (TGCNCell (args... ; gate_activation, hidden_activation, kws... ))
893902
0 commit comments