Skip to content

Commit 0780fe8

Browse files
authored
Adapt GINConv to TemporalSnapshotsGNNGraphs (#376)
* Adapt GINConv to TemporalSnapshotsGNNGraphs * Add test GINConv
1 parent 05d70de commit 0780fe8

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

src/layers/temporalconv.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,7 @@ end
186186
function Base.show(io::IO, a3tgcn::A3TGCN)
187187
print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))")
188188
end
189+
190+
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
191+
return l.(tg.snapshots, x)
192+
end

test/layers/temporalconv.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
in_channel = 3
22
out_channel = 5
33
N = 4
4+
S = 5
45
T = Float32
56

67
g1 = GNNGraph(rand_graph(N,8),
78
ndata = rand(T, in_channel, N),
89
graph_type = :sparse)
910

11+
tg = TemporalSnapshotsGNNGraph([g1 for _ in 1:S])
12+
1013
@testset "TGCNCell" begin
1114
tgcn = GraphNeuralNetworks.TGCNCell(in_channel => out_channel)
1215
h, x̃ = tgcn(tgcn.state0, g1, g1.ndata.x)
@@ -29,4 +32,11 @@ end
2932
model = GNNChain(A3TGCN(in_channel => out_channel), Dense(out_channel, 1))
3033
@test size(model(g1, g1.ndata.x)) == (1, N)
3134
@test model(g1) isa GNNGraph
35+
end
36+
37+
@testset "GINConv" begin
38+
ginconv = GINConv(Dense(in_channel => out_channel),0.3)
39+
@test length(ginconv(tg, tg.ndata.x)) == S
40+
@test size(ginconv(tg, tg.ndata.x)[1]) == (out_channel, N)
41+
@test length(Flux.gradient(x ->sum(sum(ginconv(tg, x))), tg.ndata.x)[1]) == S
3242
end

0 commit comments

Comments
 (0)