Skip to content

Commit 2319a89

Browse files
handle isolated nodes
1 parent 77f0e0f commit 2319a89

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

src/GNNGraphs/gatherscatter.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ _gather(x::Tuple, i) = map(x -> _gather(x, i), x)
33
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
44
_gather(x::Nothing, i) = nothing
55

6+
_scatter(aggr, m::Nothing, t; dstsize=nothing) = nothing
67
_scatter(aggr, m::NamedTuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
78
_scatter(aggr, m::Tuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
8-
_scatter(aggr, m::AbstractArray, t; dstsize=nothing) = NNlib.scatter(aggr, m, t; dstsize)
9-
_scatter(aggr, m::Nothing, t; dstsize=nothing) = nothing
10-
9+
function _scatter(aggr, m::AbstractMatrix, t; dstsize=nothing)
10+
if dstsize !== nothing
11+
dstsize = (size(m,1), dstsize)
12+
end
13+
NNlib.scatter(aggr, m, t; dstsize)
14+
end
1115
## TO MOVE TO NNlib ######################################################
1216

1317

src/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ where it comes after [`apply_edges`](@ref).
134134
"""
135135
function aggregate_neighbors(g::GNNGraph, aggr, m)
136136
s, t = edge_index(g)
137-
return GNNGraphs._scatter(aggr, m, t)
137+
return GNNGraphs._scatter(aggr, m, t; dstsize=g.num_nodes)
138138
end
139139

140140

test/msgpass.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,14 @@
2828
m = propagate(message, g, +, xj=X)
2929

3030
@test size(m) == (out_channel, num_V)
31-
end
3231

32+
@testset "isolated nodes" begin
33+
x1 = rand(1, 6)
34+
g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes=6)
35+
y1 = propagate((xi,xj,e) -> xj, g, +, xj=x1)
36+
@test size(y1) == (1, 6)
37+
end
38+
end
3339

3440
@testset "apply_edges" begin
3541
m = apply_edges(g, e=E) do xi, xj, e
@@ -85,7 +91,7 @@
8591
@test spmm_copyxj_fused(g) X * Adj
8692
end
8793

88-
@testset "e_mul_xj adn w_mul_xj for weighted conv" begin
94+
@testset "e_mul_xj and w_mul_xj for weighted conv" begin
8995
n = 128
9096
A = sprand(n, n, 0.1)
9197
Adj = map(x -> x > 0 ? 1 : 0, A)

0 commit comments

Comments
 (0)