Skip to content

Commit 772fc42

Browse files
more tests
1 parent 46f55c9 commit 772fc42

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/GNNGraphs/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ function check_num_nodes(g::GNNGraph, x::AbstractArray)
22
@assert g.num_nodes == size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_edges=$(g.num_nodes)"
33
return true
44
end
5-
function check_num_nodes(g::GNNGraph, x::NamedTuple)
5+
function check_num_nodes(g::GNNGraph, x::Union{Tuple,NamedTuple})
66
map(x -> check_num_nodes(g, x), x)
77
return true
88
end
@@ -13,7 +13,7 @@ function check_num_edges(g::GNNGraph, e::AbstractArray)
1313
@assert g.num_edges == size(e, ndims(e)) "Got $(size(e, ndims(e))) as last dimension size instead of num_edges=$(g.num_edges)"
1414
return true
1515
end
16-
function check_num_edges(g::GNNGraph, x::NamedTuple)
16+
function check_num_edges(g::GNNGraph, x::Union{Tuple,NamedTuple})
1717
map(x -> check_num_edges(g, x), x)
1818
return true
1919
end

test/msgpass.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,12 @@
136136
@test spmm_fused(g) X * A
137137
@test spmm_fused2(g) X * A
138138
end
139+
140+
@testset "aggregate_neighbors" begin
141+
m = rand(2, g.num_edges-1)
142+
@test_throws AssertionError aggregate_neighbors(g, +, m)
143+
144+
m = (a=rand(2, g.num_edges+1), b=nothing)
145+
@test_throws AssertionError aggregate_neighbors(g, +, m)
146+
end
139147
end

0 commit comments

Comments
 (0)