Commit be10c34

Keep only the gate activation option in TGCN and add custom non-linear (gate activation) support to the Lux TGCNCell
1 parent ee24b65 commit be10c34

4 files changed (+26 -8 lines)


GNNLux/src/layers/temporalconv.jl

Lines changed: 14 additions & 3 deletions
@@ -33,10 +33,21 @@ LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
     init_state::Function
 end
 
-function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
+function TGCNCell(ch::Pair{Int, Int};
+                  use_bias = true,
+                  init_weight = glorot_uniform,
+                  init_state = zeros32,
+                  init_bias = zeros32,
+                  add_self_loops = false,
+                  use_edge_weight = true,
+                  gate_activation = sigmoid)
     in_dims, out_dims = ch
-    conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
-    gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
+    conv = GCNConv(ch, gate_activation; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
+    gru = Lux.GRUCell(out_dims => out_dims;
+                      use_bias,
+                      init_weight = (init_weight, init_weight, init_weight),
+                      init_bias = (init_bias, init_bias, init_bias),
+                      init_state = init_state)
     return TGCNCell(in_dims, out_dims, conv, gru, init_state)
 end

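For context, the new `gate_activation` keyword replaces the hard-coded `sigmoid` in the cell's internal `GCNConv`. A minimal usage sketch on the Lux side, mirroring the new test below (the `rand_graph` toy graph, the feature shape, and the `rng` setup here are illustrative assumptions, not part of this commit):

using GNNLux, GNNGraphs, Lux, LuxCore, Random

rng = Random.default_rng()
g = rand_graph(rng, 5, 10)            # assumed toy graph: 5 nodes, 10 edges
x = rand(Float32, 3, g.num_nodes)     # assumed node features, in_dims = 3

# swap the default sigmoid gate for relu
l = TGCN(3 => 3, gate_activation = Lux.relu)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
y, st = l(g, x, ps, st)               # first element is the layer output
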
GNNLux/test/layers/temporalconv.jl

Lines changed: 8 additions & 0 deletions
@@ -56,4 +56,12 @@
         loss = (tx, ps) -> sum(sum(first(l(tg, tx, ps, st))))
         test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
     end
+
+    @testset "TGCN with Custom Activations" begin
+        l = TGCN(3=>3, gate_activation = Lux.relu)
+        ps = LuxCore.initialparameters(rng, l)
+        st = LuxCore.initialstates(rng, l)
+        loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+        test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
+    end
 end

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 3 additions & 4 deletions
@@ -826,14 +826,13 @@ Flux.@layer :noexpand TGCNCell
 
 function TGCNCell((in, out)::Pair{Int, Int};
                   gate_activation = sigmoid,
-                  hidden_activation = tanh,
                   kws...)
     conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
     dense_z = Dense(2*out => out, gate_activation)
     conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
     dense_r = Dense(2*out => out, gate_activation)
     conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
-    dense_h = Dense(2*out => out, hidden_activation)
+    dense_h = Dense(2*out => out, tanh)
     return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
 end

@@ -897,6 +896,6 @@ julia> size(y) # (d_out, timesteps, num_nodes)
 (3, 5, 5)
 ```
 """
-TGCN(args...; gate_activation = sigmoid, hidden_activation = tanh, kws...) =
-    GNNRecurrence(TGCNCell(args...; gate_activation, hidden_activation, kws...))
+TGCN(args...; gate_activation = sigmoid, kws...) =
+    GNNRecurrence(TGCNCell(args...; gate_activation, kws...))

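On the Flux side, `hidden_activation` is removed: the update and reset gates z and r now use `gate_activation` (default `sigmoid`), while the candidate hidden state is always computed with `tanh`, as in the standard GRU formulation. A minimal usage sketch of the resulting API (graph size and feature dimensions below are illustrative assumptions):

using GraphNeuralNetworks, Flux

g = rand_graph(5, 10)                  # assumed toy graph: 5 nodes, 10 edges
x = rand(Float32, 3, 5, g.num_nodes)   # (d_in, timesteps, num_nodes)

# relu gates instead of the default sigmoid; the hidden update stays tanh
layer = TGCN(3 => 3, gate_activation = relu)
y = layer(g, x)                        # size (d_out, timesteps, num_nodes) = (3, 5, 5)
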
GraphNeuralNetworks/test/layers/temporalconv.jl

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ end
 using .TemporalConvTestModule, .TestModule
 using Flux: relu, sigmoid
 layer_default = TGCN(in_channel => out_channel)
-layer_custom = TGCN(in_channel => out_channel, gate_activation = relu, hidden_activation = relu)
+layer_custom = TGCN(in_channel => out_channel, gate_activation = relu)
 x = rand(Float32, in_channel, timesteps, g.num_nodes)
 y_default = layer_default(g, x)
 y_custom = layer_custom(g, x)
