Skip to content

Commit 84a99ec

Browse files
committed
assert squareness of adj_matrix
1 parent 609c6e6 commit 84a99ec

File tree

7 files changed

+38
-8
lines changed

7 files changed

+38
-8
lines changed

src/NetworkLayout.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ struct LayoutIterator{T<:IterativeLayout,M<:AbstractMatrix}
6969
end
7070

7171
function layout(alg::IterativeLayout, adj_matrix::AbstractMatrix)
72+
assertsquare(adj_matrix)
7273
iter = LayoutIterator(alg, adj_matrix)
7374
next = Base.iterate(iter)
7475
pos = next[1]
@@ -84,6 +85,17 @@ function __init__()
8485
@require LightGraphs="093fc24a-ae57-5d10-9952-331d41423f4d" layout(l::AbstractLayout, g::LightGraphs.AbstractGraph) = layout(l, LightGraphs.adjacency_matrix(g))
8586
end
8687

88+
"""
89+
assertsquare(M)
90+
91+
Throws `ArgumentArror` if matrix is not square. Returns size.
92+
"""
93+
function assertsquare(M::AbstractMatrix)
94+
(a, b) = size(M)
95+
a != b && throw(ArgumentError("Adjecency Matrix needs to be square!"))
96+
return a
97+
end
98+
8799
include("sfdp.jl")
88100
include("buchheim.jl")
89101
include("spring.jl")

src/buchheim.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function adj_mat_to_list(M::AbstractMatrix)
6464
end
6565

6666
function layout(para::Buchheim, adj_matrix::AbstractMatrix)
67-
@assert size(adj_matrix, 1) == size(adj_matrix, 2) "adjacency matrix not square!"
67+
assertsquare(adj_matrix)
6868
list = adj_mat_to_list(adj_matrix)
6969
layout(para, list)
7070
end

src/circular.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ struct Circular{Ptype} <: AbstractLayout{2,Ptype} end
1717
Circular(; Ptype=Float64) = Circular{Ptype}()
1818

1919
function layout(::Circular{Ptype}, adj_matrix::AbstractMatrix) where {Ptype}
20-
if size(adj_matrix, 1) == 1
20+
N = assertsquare(adj_matrix)
21+
if N == 1
2122
return Point{2,Ptype}[Point(0.0, 0.0)]
2223
else
2324
# Discard the extra angle since it matches 0 radians.
24-
θ = range(0; stop=2pi, length=size(adj_matrix, 1) + 1)[1:(end - 1)]
25+
θ = range(0; stop=2pi, length=N + 1)[1:(end - 1)]
2526
return Point{2,Ptype}[(cos(o), sin(o)) for o in θ]
2627
end
2728
end

src/shell.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ end
2626
Shell(; Ptype=Float64, nlist=Vector{Int}[]) = Shell{Ptype}(nlist)
2727

2828
function layout(algo::Shell{Ptype}, adj_matrix::AbstractMatrix) where {Ptype}
29-
if size(adj_matrix, 1) == 1
29+
N = assertsquare(adj_matrix)
30+
if N == 1
3031
return Point{2,Float64}[Point(0.0, 0.0)]
3132
end
3233

33-
N = size(adj_matrix, 1)
3434
nlist = copy(algo.nlist)
3535

3636
# if the list does not contain all the nodes push missing nodes to new shell

src/spectral.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,12 @@ function compute_laplacian(adj_matrix, node_weights)
5656
end
5757

5858
function layout(algo::Spectral{Ptype,FT}, adj_matrix::AbstractMatrix) where {Ptype,FT}
59+
N = assertsquare(adj_matrix)
5960
# try to use user provided nodeweights
60-
nodeweights = if length(algo.nodeweights) == size(adj_matrix, 1)
61+
nodeweights = if length(algo.nodeweights) == N
6162
algo.nodeweights
6263
else
63-
ones(FT, size(adj_matrix, 1))
64+
ones(FT, N)
6465
end
6566

6667
adj_matrix = make_symmetric(adj_matrix)

src/squaregrid.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function SquareGrid(; Ptype=Float64, cols=:auto, dx=Ptype(1), dy=Ptype(-1), skip
3030
end
3131

3232
function layout(algo::SquareGrid{Ptype}, adj_matrix::AbstractMatrix) where {Ptype}
33-
N = size(adj_matrix, 1)
33+
N = assertsquare(adj_matrix)
3434
M = N + length(algo.skip)
3535

3636
if algo.cols === :auto

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,4 +281,20 @@ jagmesh_adj = jagmesh()
281281
pos = Spring()(g)
282282
@test pos isa Vector{Point{2,Float64}}
283283
end
284+
285+
@testset "test assert square" begin
286+
using NetworkLayout: assertsquare
287+
M1 = rand(2,4)
288+
@test_throws ArgumentError assertsquare(M1)
289+
M2 = rand(4,4)
290+
@test assertsquare(M2) == 4
291+
@test_throws ArgumentError layout(Buchheim(), M1)
292+
@test_throws ArgumentError layout(Circular(), M1)
293+
@test_throws ArgumentError layout(SFDP(), M1)
294+
@test_throws ArgumentError layout(Shell(), M1)
295+
@test_throws ArgumentError layout(Spectral(), M1)
296+
@test_throws ArgumentError layout(Spring(), M1)
297+
@test_throws ArgumentError layout(SquareGrid(), M1)
298+
@test_throws ArgumentError layout(Stress(), M1)
299+
end
284300
end

0 commit comments

Comments
 (0)