Skip to content

Commit 9541594

Browse files
Merge pull request #139 from CarloLucibello/cl/isolated
handle isolated nodes
2 parents 77f0e0f + a6a9c49 commit 9541594

File tree

5 files changed

+48
-30
lines changed

5 files changed

+48
-30
lines changed

src/GNNGraphs/gatherscatter.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,19 @@ _gather(x::Tuple, i) = map(x -> _gather(x, i), x)
33
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
44
_gather(x::Nothing, i) = nothing
55

6-
_scatter(aggr, m::NamedTuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
7-
_scatter(aggr, m::Tuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
8-
_scatter(aggr, m::AbstractArray, t; dstsize=nothing) = NNlib.scatter(aggr, m, t; dstsize)
9-
_scatter(aggr, m::Nothing, t; dstsize=nothing) = nothing
6+
_scatter(aggr, src::Nothing, idx, n) = nothing
7+
_scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
8+
_scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
9+
10+
function _scatter(aggr,
11+
src::AbstractArray,
12+
idx::AbstractVector{<:Integer},
13+
n::Integer)
14+
15+
dstsize = (size(src)[1:end-1]..., n)
16+
NNlib.scatter(aggr, src, idx; dstsize)
17+
end
18+
1019

1120
## TO MOVE TO NNlib ######################################################
1221

src/GNNGraphs/transform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr=+)
109109
idxs .= 1:num_edges
110110
idxs .= idxs .- cumsum(.!mask)
111111
num_edges = length(s)
112-
w = _scatter(aggr, w, idxs)
113-
edata = _scatter(aggr, edata, idxs)
112+
w = _scatter(aggr, w, idxs, num_edges)
113+
edata = _scatter(aggr, edata, idxs, num_edges)
114114
end
115115

116116
GNNGraph((s, t, w),

src/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ where it comes after [`apply_edges`](@ref).
134134
"""
135135
function aggregate_neighbors(g::GNNGraph, aggr, m)
136136
s, t = edge_index(g)
137-
return GNNGraphs._scatter(aggr, m, t)
137+
return GNNGraphs._scatter(aggr, m, t, g.num_nodes)
138138
end
139139

140140

test/layers/conv.jl

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
@testset "Conv Layers" begin
2+
RTOL_LOW = 1e-2
3+
RTOL_HIGH = 1e-5
4+
25
in_channel = 3
36
out_channel = 5
47
N = 4
@@ -27,16 +30,16 @@
2730
@testset "GCNConv" begin
2831
l = GCNConv(in_channel => out_channel)
2932
for g in test_graphs
30-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
33+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
3134
end
3235

3336
l = GCNConv(in_channel => out_channel, tanh, bias=false)
3437
for g in test_graphs
35-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
38+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
3639
end
3740

3841
l = GCNConv(in_channel => out_channel, add_self_loops=false)
39-
test_layer(l, g1, rtol=1e-5, outsize=(out_channel, g1.num_nodes))
42+
test_layer(l, g1, rtol=RTOL_HIGH, outsize=(out_channel, g1.num_nodes))
4043

4144
@testset "edge weights" begin
4245
s = [2,3,1,3,1,2]
@@ -57,7 +60,7 @@
5760
x = rand(T, 1, 3)
5861
g = GNNGraph((s, t, w), ndata=x, graph_type=GRAPH_T, edata=w)
5962
l = GCNConv(1 => 1, add_self_loops=false, use_edge_weight=true)
60-
test_layer(l, g, rtol=1e-5, outsize=(1, g.num_nodes), test_gpu=false)
63+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(1, g.num_nodes), test_gpu=false)
6164
@test gradient(w -> sum(l(g, x, w)), w)[1] isa AbstractVector{T} # redundan test but more esplicit
6265
end
6366
end
@@ -70,9 +73,9 @@
7073
@test l.k == k
7174
for g in test_graphs
7275
g = add_self_loops(g)
73-
test_layer(l, g, rtol=1e-5, test_gpu=false, outsize=(out_channel, g.num_nodes))
76+
test_layer(l, g, rtol=RTOL_HIGH, test_gpu=false, outsize=(out_channel, g.num_nodes))
7477
if TEST_GPU
75-
@test_broken test_layer(l, g, rtol=1e-5, test_gpu=true, outsize=(out_channel, g.num_nodes))
78+
@test_broken test_layer(l, g, rtol=RTOL_HIGH, test_gpu=true, outsize=(out_channel, g.num_nodes))
7679
end
7780
end
7881

@@ -85,12 +88,12 @@
8588
@testset "GraphConv" begin
8689
l = GraphConv(in_channel => out_channel)
8790
for g in test_graphs
88-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
91+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
8992
end
9093

9194
l = GraphConv(in_channel => out_channel, relu, bias=false, aggr=mean)
9295
for g in test_graphs
93-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
96+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
9497
end
9598

9699
@testset "bias=false" begin
@@ -104,7 +107,7 @@
104107
for heads in (1, 2), concat in (true, false)
105108
l = GATConv(in_channel => out_channel; heads, concat)
106109
for g in test_graphs
107-
test_layer(l, g, rtol=1e-3,
110+
test_layer(l, g, rtol=RTOL_LOW,
108111
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
109112
end
110113
end
@@ -113,7 +116,7 @@
113116
ein = 3
114117
l = GATConv((in_channel, ein) => out_channel, add_self_loops=false)
115118
g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
116-
test_layer(l, g, rtol=1e-3, outsize=(out_channel, g.num_nodes))
119+
test_layer(l, g, rtol=RTOL_LOW, outsize=(out_channel, g.num_nodes))
117120
end
118121

119122
@testset "num params" begin
@@ -131,7 +134,7 @@
131134
for heads in (1, 2), concat in (true, false)
132135
l = GATv2Conv(in_channel => out_channel, tanh; heads, concat)
133136
for g in test_graphs
134-
test_layer(l, g, rtol=1e-3,
137+
test_layer(l, g, rtol=RTOL_LOW,
135138
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
136139
end
137140
end
@@ -140,7 +143,7 @@
140143
ein = 3
141144
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops=false)
142145
g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
143-
test_layer(l, g, rtol=1e-3, outsize=(out_channel, g.num_nodes))
146+
test_layer(l, g, rtol=RTOL_LOW, outsize=(out_channel, g.num_nodes))
144147
end
145148

146149
@testset "num params" begin
@@ -167,14 +170,14 @@
167170
@test size(l.weight) == (out_channel, out_channel, num_layers)
168171

169172
for g in test_graphs
170-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
173+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
171174
end
172175
end
173176

174177
@testset "EdgeConv" begin
175178
l = EdgeConv(Dense(2*in_channel, out_channel), aggr=+)
176179
for g in test_graphs
177-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
180+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
178181
end
179182
end
180183

@@ -183,7 +186,7 @@
183186

184187
l = GINConv(nn, 0.01f0, aggr=mean)
185188
for g in test_graphs
186-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
189+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
187190
end
188191

189192
@test !in(:eps, Flux.trainable(l))
@@ -196,7 +199,7 @@
196199
l = NNConv(in_channel => out_channel, nn, tanh, bias=true, aggr=+)
197200
for g in test_graphs
198201
g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
199-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
202+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
200203
end
201204
end
202205

@@ -206,15 +209,15 @@
206209

207210
l = SAGEConv(in_channel => out_channel, tanh, bias=false, aggr=+)
208211
for g in test_graphs
209-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
212+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
210213
end
211214
end
212215

213216

214217
@testset "ResGatedGraphConv" begin
215218
l = ResGatedGraphConv(in_channel => out_channel, tanh, bias=true)
216219
for g in test_graphs
217-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
220+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
218221
end
219222
end
220223

@@ -224,7 +227,7 @@
224227
l = CGConv((in_channel, edim) => out_channel, tanh, residual=false, bias=true)
225228
for g in test_graphs
226229
g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
227-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
230+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
228231
end
229232

230233
# no edge features
@@ -238,15 +241,15 @@
238241
l = AGNNConv()
239242
@test l.β == [1f0]
240243
for g in test_graphs
241-
test_layer(l, g, rtol=1e-5, outsize=(in_channel, g.num_nodes))
244+
test_layer(l, g, rtol=RTOL_HIGH, outsize=(in_channel, g.num_nodes))
242245
end
243246
end
244247

245248
@testset "MEGNetConv" begin
246249
l = MEGNetConv(in_channel => out_channel, aggr=+)
247250
for g in test_graphs
248251
g = GNNGraph(g, edata=rand(T, in_channel, g.num_edges))
249-
test_layer(l, g, rtol=1e-3,
252+
test_layer(l, g, rtol=RTOL_LOW,
250253
outtype=:node_edge,
251254
outsize=((out_channel, g.num_nodes), (out_channel, g.num_edges)))
252255
end

test/msgpass.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,14 @@
2828
m = propagate(message, g, +, xj=X)
2929

3030
@test size(m) == (out_channel, num_V)
31-
end
3231

32+
@testset "isolated nodes" begin
33+
x1 = rand(1, 6)
34+
g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes=6)
35+
y1 = propagate((xi,xj,e) -> xj, g, +, xj=x1)
36+
@test size(y1) == (1, 6)
37+
end
38+
end
3339

3440
@testset "apply_edges" begin
3541
m = apply_edges(g, e=E) do xi, xj, e
@@ -85,7 +91,7 @@
8591
@test spmm_copyxj_fused(g) X * Adj
8692
end
8793

88-
@testset "e_mul_xj adn w_mul_xj for weighted conv" begin
94+
@testset "e_mul_xj and w_mul_xj for weighted conv" begin
8995
n = 128
9096
A = sprand(n, n, 0.1)
9197
Adj = map(x -> x > 0 ? 1 : 0, A)

0 commit comments

Comments
 (0)