Skip to content

Commit 46f55c9

Browse files
more size checks
1 parent 9398a53 commit 46f55c9

File tree

4 files changed

+36
-3
lines changed

4 files changed

+36
-3
lines changed

src/GNNGraphs/gatherscatter.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function _scatter(aggr,
1313
n::Integer)
1414

1515
dstsize = (size(src)[1:end-1]..., n)
16-
NNlib.scatter(aggr, src, idx; dstsize)
16+
return NNlib.scatter(aggr, src, idx; dstsize)
1717
end
1818

1919

src/GNNGraphs/utils.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,25 @@
11
function check_num_nodes(g::GNNGraph, x::AbstractArray)
2-
@assert g.num_nodes == size(x, ndims(x))
2+
@assert g.num_nodes == size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_edges=$(g.num_nodes)"
3+
return true
34
end
5+
function check_num_nodes(g::GNNGraph, x::NamedTuple)
6+
map(x -> check_num_nodes(g, x), x)
7+
return true
8+
end
9+
10+
check_num_nodes(::GNNGraph, ::Nothing) = true
11+
412
function check_num_edges(g::GNNGraph, e::AbstractArray)
5-
@assert g.num_edges == size(e, ndims(e))
13+
@assert g.num_edges == size(e, ndims(e)) "Got $(size(e, ndims(e))) as last dimension size instead of num_edges=$(g.num_edges)"
14+
return true
615
end
16+
function check_num_edges(g::GNNGraph, x::NamedTuple)
17+
map(x -> check_num_edges(g, x), x)
18+
return true
19+
end
20+
21+
check_num_edges(::GNNGraph, ::Nothing) = true
22+
723

824
sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...)
925

src/msgpass.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ apply_edges(l, g::GNNGraph; xi=nothing, xj=nothing, e=nothing) =
111111
apply_edges(l, g, xi, xj, e)
112112

113113
function apply_edges(f, g::GNNGraph, xi, xj, e)
114+
check_num_nodes(g, xi)
115+
check_num_nodes(g, xj)
116+
check_num_edges(g, e)
114117
s, t = edge_index(g)
115118
xi = GNNGraphs._gather(xi, t) # size: (D, num_nodes) -> (D, num_edges)
116119
xj = GNNGraphs._gather(xj, s)
@@ -133,6 +136,7 @@ Neighborhood aggregation is the second step of [`propagate`](@ref),
133136
where it comes after [`apply_edges`](@ref).
134137
"""
135138
function aggregate_neighbors(g::GNNGraph, aggr, m)
139+
check_num_edges(g, m)
136140
s, t = edge_index(g)
137141
return GNNGraphs._scatter(aggr, m, t, g.num_nodes)
138142
end

test/msgpass.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@
6262
end
6363

6464
@test m.a == ones(out_channel, num_E)
65+
66+
@testset "sizecheck" begin
67+
x = rand(3, g.num_nodes-1)
68+
@test_throws AssertionError apply_edges(copy_xj, g, xj=x)
69+
@test_throws AssertionError apply_edges(copy_xj, g, xi=x)
70+
71+
x = (a=rand(3, g.num_nodes), b=rand(3, g.num_nodes+1))
72+
@test_throws AssertionError apply_edges(copy_xj, g, xj=x)
73+
@test_throws AssertionError apply_edges(copy_xj, g, xi=x)
74+
75+
e = rand(3, g.num_edges-1)
76+
@test_throws AssertionError apply_edges(copy_xj, g, e=e)
77+
end
6578
end
6679

6780
@testset "copy_xj" begin

0 commit comments

Comments
 (0)