Skip to content

Commit 50ebac3

Browse files
authored
Add more conv layers with TemporalSnapshotsGNNGraphs support (#393)
* temporal layers * ticks in md file
1 parent 724bdc2 commit 50ebac3

File tree

3 files changed

+92
-9
lines changed

3 files changed

+92
-9
lines changed

docs/src/api/conv.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,22 @@ The table below lists all graph convolutional layers implemented in the *GraphNe
1818
| Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs |
1919
| :-------- | :---: |:---: |:---: | :---: | :---: |
2020
| [`AGNNConv`](@ref) | | || | |
21-
| [`CGConv`](@ref) | | || | |
22-
| [`ChebConv`](@ref) | | | | | |
21+
| [`CGConv`](@ref) | | || | |
22+
| [`ChebConv`](@ref) | | | | | |
2323
| [`EGNNConv`](@ref) | | || | |
2424
| [`EdgeConv`](@ref) | | | | | |
25-
| [`GATConv`](@ref) | | || | |
26-
| [`GATv2Conv`](@ref) | | || | |
27-
| [`GatedGraphConv`](@ref) || | | | |
28-
| [`GCNConv`](@ref) ||| | | |
25+
| [`GATConv`](@ref) | | || | |
26+
| [`GATv2Conv`](@ref) | | || | |
27+
| [`GatedGraphConv`](@ref) || | | | |
28+
| [`GCNConv`](@ref) ||| | | |
2929
| [`GINConv`](@ref) || | | ||
3030
| [`GMMConv`](@ref) | | || | |
3131
| [`GraphConv`](@ref) || | |||
3232
| [`MEGNetConv`](@ref) | | || | |
3333
| [`NNConv`](@ref) | | || | |
3434
| [`ResGatedGraphConv`](@ref) | | | | ||
3535
| [`SAGEConv`](@ref) || | | ||
36-
| [`SGConv`](@ref) || | | | |
36+
| [`SGConv`](@ref) || | | | |
3737
| [`TransformerConv`](@ref) | | || | |
3838

3939

src/layers/temporalconv.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,38 @@ function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
191191
return l.(tg.snapshots, x)
192192
end
193193

194+
function (l::ChebConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
195+
return l.(tg.snapshots, x)
196+
end
197+
198+
function (l::GATConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
199+
return l.(tg.snapshots, x)
200+
end
201+
202+
function (l::GATv2Conv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
203+
return l.(tg.snapshots, x)
204+
end
205+
206+
function (l::GatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
207+
return l.(tg.snapshots, x)
208+
end
209+
210+
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
211+
return l.(tg.snapshots, x)
212+
end
213+
214+
function (l::CGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
215+
return l.(tg.snapshots, x)
216+
end
217+
218+
function (l::SGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
219+
return l.(tg.snapshots, x)
220+
end
221+
222+
function (l::TransformerConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
223+
return l.(tg.snapshots, x)
224+
end
225+
194226
function (l::GCNConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
195227
return l.(tg.snapshots, x)
196228
end
@@ -205,4 +237,4 @@ end
205237

206238
function (l::GraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
207239
return l.(tg.snapshots, x)
208-
end
240+
end

test/layers/temporalconv.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,56 @@ end
4141
@test length(Flux.gradient(x ->sum(sum(ginconv(tg, x))), tg.ndata.x)[1]) == S
4242
end
4343

44+
45+
@testset "ChebConv" begin
46+
chebconv = ChebConv(in_channel => out_channel, 5)
47+
@test length(chebconv(tg, tg.ndata.x)) == S
48+
@test size(chebconv(tg, tg.ndata.x)[1]) == (out_channel, N)
49+
@test length(Flux.gradient(x ->sum(sum(chebconv(tg, x))), tg.ndata.x)[1]) == S
50+
end
51+
52+
@testset "GATConv" begin
53+
gatconv = GATConv(in_channel => out_channel)
54+
@test length(gatconv(tg, tg.ndata.x)) == S
55+
@test size(gatconv(tg, tg.ndata.x)[1]) == (out_channel, N)
56+
@test length(Flux.gradient(x ->sum(sum(gatconv(tg, x))), tg.ndata.x)[1]) == S
57+
end
58+
59+
@testset "GATv2Conv" begin
60+
gatv2conv = GATv2Conv(in_channel => out_channel)
61+
@test length(gatv2conv(tg, tg.ndata.x)) == S
62+
@test size(gatv2conv(tg, tg.ndata.x)[1]) == (out_channel, N)
63+
@test length(Flux.gradient(x ->sum(sum(gatv2conv(tg, x))), tg.ndata.x)[1]) == S
64+
end
65+
66+
@testset "GatedGraphConv" begin
67+
gatedgraphconv = GatedGraphConv(5, 5)
68+
@test length(gatedgraphconv(tg, tg.ndata.x)) == S
69+
@test size(gatedgraphconv(tg, tg.ndata.x)[1]) == (out_channel, N)
70+
@test length(Flux.gradient(x ->sum(sum(gatedgraphconv(tg, x))), tg.ndata.x)[1]) == S
71+
end
72+
73+
@testset "CGConv" begin
74+
cgconv = CGConv(in_channel => out_channel)
75+
@test length(cgconv(tg, tg.ndata.x)) == S
76+
@test size(cgconv(tg, tg.ndata.x)[1]) == (out_channel, N)
77+
@test length(Flux.gradient(x ->sum(sum(cgconv(tg, x))), tg.ndata.x)[1]) == S
78+
end
79+
80+
@testset "SGConv" begin
81+
sgconv = SGConv(in_channel => out_channel)
82+
@test length(sgconv(tg, tg.ndata.x)) == S
83+
@test size(sgconv(tg, tg.ndata.x)[1]) == (out_channel, N)
84+
@test length(Flux.gradient(x ->sum(sum(sgconv(tg, x))), tg.ndata.x)[1]) == S
85+
end
86+
87+
@testset "TransformerConv" begin
88+
transformerconv = TransformerConv(in_channel => out_channel)
89+
@test length(transformerconv(tg, tg.ndata.x)) == S
90+
@test size(transformerconv(tg, tg.ndata.x)[1]) == (out_channel, N)
91+
@test length(Flux.gradient(x ->sum(sum(transformerconv(tg, x))), tg.ndata.x)[1]) == S
92+
end
93+
4494
@testset "GCNConv" begin
4595
gcnconv = GCNConv(in_channel => out_channel)
4696
@test length(gcnconv(tg, tg.ndata.x)) == S
@@ -67,4 +117,5 @@ end
67117
@test length(graphconv(tg, tg.ndata.x)) == S
68118
@test size(graphconv(tg, tg.ndata.x)[1]) == (out_channel, N)
69119
@test length(Flux.gradient(x ->sum(sum(graphconv(tg, x))), tg.ndata.x)[1]) == S
70-
end
120+
end
121+

0 commit comments

Comments
 (0)