Skip to content

Commit 884b473

Browse files
authored
Adapt 4 convolutions to TemporalSnapshotsGNNGraphs (#392)
* Adapt 4 conv * Add test 4 convs * Add checks in the docs
1 parent 4e14f67 commit 884b473

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

docs/src/api/conv.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ The table below lists all graph convolutional layers implemented in the *GraphNe
2525
| [`GATConv`](@ref) | | || | |
2626
| [`GATv2Conv`](@ref) | | || | |
2727
| [`GatedGraphConv`](@ref) || | | | |
28-
| [`GCNConv`](@ref) ||| | | |
28+
| [`GCNConv`](@ref) ||| | | |
2929
| [`GINConv`](@ref) || | | ||
3030
| [`GMMConv`](@ref) | | || | |
31-
| [`GraphConv`](@ref) || | || |
31+
| [`GraphConv`](@ref) || | || |
3232
| [`MEGNetConv`](@ref) | | || | |
3333
| [`NNConv`](@ref) | | || | |
34-
| [`ResGatedGraphConv`](@ref) | | | | | |
35-
| [`SAGEConv`](@ref) || | | | |
34+
| [`ResGatedGraphConv`](@ref) | | | | | |
35+
| [`SAGEConv`](@ref) || | | | |
3636
| [`SGConv`](@ref) || | | | |
3737
| [`TransformerConv`](@ref) | | || | |
3838

src/layers/temporalconv.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,20 @@ end
189189

190190
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
191191
return l.(tg.snapshots, x)
192+
end
193+
194+
function (l::GCNConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
195+
return l.(tg.snapshots, x)
196+
end
197+
198+
function (l::ResGatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
199+
return l.(tg.snapshots, x)
200+
end
201+
202+
function (l::SAGEConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
203+
return l.(tg.snapshots, x)
204+
end
205+
206+
function (l::GraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
207+
return l.(tg.snapshots, x)
192208
end

test/layers/temporalconv.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,32 @@ end
3939
@test length(ginconv(tg, tg.ndata.x)) == S
4040
@test size(ginconv(tg, tg.ndata.x)[1]) == (out_channel, N)
4141
@test length(Flux.gradient(x ->sum(sum(ginconv(tg, x))), tg.ndata.x)[1]) == S
42+
end
43+
44+
@testset "GCNConv" begin
45+
gcnconv = GCNConv(in_channel => out_channel)
46+
@test length(gcnconv(tg, tg.ndata.x)) == S
47+
@test size(gcnconv(tg, tg.ndata.x)[1]) == (out_channel, N)
48+
@test length(Flux.gradient(x ->sum(sum(gcnconv(tg, x))), tg.ndata.x)[1]) == S
49+
end
50+
51+
@testset "ResGatedGraphConv" begin
52+
resgatedconv = ResGatedGraphConv(in_channel => out_channel, relu)
53+
@test length(resgatedconv(tg, tg.ndata.x)) == S
54+
@test size(resgatedconv(tg, tg.ndata.x)[1]) == (out_channel, N)
55+
@test length(Flux.gradient(x ->sum(sum(resgatedconv(tg, x))), tg.ndata.x)[1]) == S
56+
end
57+
58+
@testset "SAGEConv" begin
59+
sageconv = SAGEConv(in_channel => out_channel)
60+
@test length(sageconv(tg, tg.ndata.x)) == S
61+
@test size(sageconv(tg, tg.ndata.x)[1]) == (out_channel, N)
62+
@test length(Flux.gradient(x ->sum(sum(sageconv(tg, x))), tg.ndata.x)[1]) == S
63+
end
64+
65+
@testset "GraphConv" begin
66+
graphconv = GraphConv(in_channel => out_channel,relu)
67+
@test length(graphconv(tg, tg.ndata.x)) == S
68+
@test size(graphconv(tg, tg.ndata.x)[1]) == (out_channel, N)
69+
@test length(Flux.gradient(x ->sum(sum(graphconv(tg, x))), tg.ndata.x)[1]) == S
4270
end

0 commit comments

Comments
 (0)