Skip to content

Commit d0b0089

Browse files
committed
Add standard aggregation tests
1 parent 0c81482 commit d0b0089

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,11 @@ end
244244
test_symmetric_soc()
245245
end
246246

247+
@testset "Standard Aggregation" begin
248+
249+
test_standard_aggregation()
250+
end
251+
247252
end
248253

249254
end

test/sa_tests.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,82 @@ function generate_matrices()
5656

5757
cases
5858
end
59+
60+
function stand_agg(C)
61+
n = size(C, 1)
62+
63+
R = Set(1:n)
64+
j = 0
65+
Cpts = Int[]
66+
67+
aggregates = -ones(Int, n)
68+
69+
# Pass 1
70+
for i = 1:n
71+
Ni = union!(Set(C.rowval[nzrange(C, i)]), Set(i))
72+
if issubset(Ni, R)
73+
push!(Cpts, i)
74+
setdiff!(R, Ni)
75+
for x in Ni
76+
aggregates[x] = j
77+
end
78+
j += 1
79+
end
80+
end
81+
82+
# Pass 2
83+
old_R = copy(R)
84+
for i = 1:n
85+
if ! (i in R)
86+
continue
87+
end
88+
89+
for x in C.rowval[nzrange(C, i)]
90+
if !(x in old_R)
91+
aggregates[i] = aggregates[x]
92+
setdiff!(R, i)
93+
break
94+
end
95+
end
96+
end
97+
98+
# Pass 3
99+
for i = 1:n
100+
if !(i in R)
101+
continue
102+
end
103+
Ni = union(Set(C.rowval[nzrange(C,i)]), Set(i))
104+
push!(Cpts, i)
105+
106+
for x in Ni
107+
if x in R
108+
aggregates[x] = j
109+
end
110+
j += 1
111+
end
112+
end
113+
114+
@assert length(R) == 0
115+
116+
Pj = aggregates + 1
117+
Pp = collect(1:n+1)
118+
Px = ones(eltype(C), n)
119+
120+
SparseMatrixCSC(maximum(aggregates + 1), n, Pp, Pj, Px)
121+
end
122+
123+
# Standard aggregation tests
124+
function test_standard_aggregation()
125+
126+
cases = generate_matrices()
127+
128+
for matrix in cases
129+
for θ in (0.0, 0.1, 0.5, 1., 10.)
130+
C = symmetric_soc(matrix, θ)
131+
calc_matrix = aggregation(StandardAggregation(), matrix)
132+
ref_matrix = stand_agg(matrix)
133+
@test sum(abs2, ref_matrix - calc_matrix) < 1e-6
134+
end
135+
end
136+
137+
end

0 commit comments

Comments
 (0)