Skip to content

Commit 2d70b0d

Browse files
committed
Add a few fit candidates tests
1 parent d0b0089 commit 2d70b0d

File tree

3 files changed

+61
-101
lines changed

3 files changed

+61
-101
lines changed

src/aggregation.jl

Lines changed: 9 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -81,51 +81,20 @@ construct_R(::HermitianSymmetry, P) = P'
8181

8282
function fit_candidates(AggOp, B, tol = 1e-10)
8383

84-
# K1 = Int(size(B, 1) / N_fine)
85-
# K2 = size(B, 2)
86-
8784
A = AggOp.'
88-
8985
n_coarse = size(A, 2)
9086
n_fine = size(A, 1)
9187
n_col = n_coarse
9288

93-
# R = zeros(eltype(B), N_coarse, K2, K2)
9489
R = zeros(eltype(B), n_coarse)
95-
# Qx = zeros(eltype(B), nnz(AggOp), K1, K2)
9690
Qx = zeros(eltype(B), nnz(A))
9791

98-
99-
# R = vec(R)
100-
# Qx = vec(Qx)
101-
102-
# n_row = N_fine
103-
# n_col = N_coarse
104-
105-
# BS = K1 * K2
106-
107-
#=for i = 1:n_col
108-
Ax_start = 1 + BS * A.colptr[i]
109-
110-
for j in nzrange(A, i)
111-
B_start = 1 + BS * A.rowval[j]
112-
B_end = B_start + BS
113-
@show B_start
114-
@show B_end
115-
for ind in B_start:B_end
116-
A.nzval[ind + Ax_start] = B[ind]
117-
end
118-
Ax_start += BS
119-
end
120-
end=#
121-
# copy!(A.nzval, B)
12292
copy!(Qx, B)
123-
# @show size(A.nzval)
124-
# @show size(B)
93+
copy!(A.nzval, B)
12594

95+
k = 1
12696
for i = 1:n_col
12797
norm_i = norm_col(A, Qx, i)
128-
# norm_i = norm(A[:,i])
12998
threshold_i = tol * norm_i
13099
if norm_i > threshold_i
131100
scale = 1 / norm_i
@@ -136,82 +105,21 @@ function fit_candidates(AggOp, B, tol = 1e-10)
136105
end
137106
for j in nzrange(A, i)
138107
row = A.rowval[j]
139-
Qx[row] *= scale
108+
# Qx[row] *= scale
109+
#@show k
110+
Qx[k] *= scale
111+
k += 1
140112
end
141-
#=col_start = A.colptr[i]
142-
col_end = A.colptr[i+1]
143-
144-
Ax_start = 1 + BS * col_start
145-
Ax_end = 1 + BS * col_end
146-
R_start = 1 + i * K2 * K2
147-
148-
for bj = 1:K2
149-
norm_j = zero(eltype(A))
150-
Ax_col = Ax_start + bj
151-
while Ax_col < Ax_end
152-
norm_j += norm(A.nzval[Ax_col])
153-
Ax_col += K2
154-
end
155-
norm_j = sqrt(norm_j)
156-
157-
threshold_j = tol * norm_j
158-
159-
for bi = 1:bj
160-
dot_prod = zero(eltype(A))
161-
162-
Ax_bi = Ax_start + bj
163-
Ax_bj = Ax_start + bj
164-
while Ax_bi < Ax_end
165-
dot_prod += dot(A.nzval[Ax_bj], A.nzval[Ax_bi])
166-
Ax_bi += K2
167-
Ax_bj += K2
168-
end
169-
170-
Ax_bi = Ax_start + bi;
171-
Ax_bj = Ax_start + bj;
172-
while Ax_bi < Ax_end
173-
A.nzval[Ax_bj] -= dot_prod * A.nzval[Ax_bi]
174-
Ax_bi += K2
175-
Ax_bj += K2
176-
end
177-
178-
R[R_start + K2 * bi + bj] = dot_prod
179-
end
180-
181-
norm_j = zero(eltype(A))
182-
Ax_bj = Ax_start + bj
183-
while Ax_bj < Ax_end
184-
norm_j += norm(A.nzval[Ax_bj])
185-
Ax_bj += K2
186-
norm_j = sqrt(norm_j)
187-
end
188-
189-
if norm_j > threshold_j
190-
scale = 1 / norm_j
191-
R[R_start + K2 * bj + bj] = norm_j
192-
else
193-
scale = zero(eltype(A))
194-
R[R_start + K2 * bj + bj] = 0
195-
end
196-
197-
Ax_bj = Ax_start + bj
198-
while Ax_bj < Ax_end
199-
A.nzval[Ax_bj] *= scale
200-
Ax_bj += K2
201-
end
202-
end=#
203113
end
204114

205-
#Q = SparseMatrixCSC(N_coarse, N_fine,
206-
#.colptr, A.rowval, Qx)
207-
208-
#R = reshape(R, N_coarse, K2)
209115
SparseMatrixCSC(size(A)..., A.colptr, A.rowval, Qx), R
116+
# A, R
210117
end
211118
function norm_col(A, Qx, i)
212119
s = zero(eltype(A))
213120
for j in nzrange(A, i)
214-
val = Qx[A.rowval[j]]
121+
# val = Qx[A.rowval[j]]
122+
val = A.nzval[A.rowval[j]]
215123
s += val*val
216124
end
217125
sqrt(s)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ end
249249
test_standard_aggregation()
250250
end
251251

252+
@testset "Fit Candidates" begin
253+
test_fit_candidates()
254+
end
252255
end
253256

254257
end

test/sa_tests.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,53 @@ function test_standard_aggregation()
134134
end
135135
end
136136

137+
end
138+
139+
# Test fit_candidates
140+
function test_fit_candidates()
141+
142+
cases = generate_fit_candidates_cases()
143+
144+
for (i, (AggOp, fine_candidates)) in enumerate(cases)
145+
146+
mask_candidates!(AggOp, fine_candidates)
147+
148+
Q, coarse_candidates = fit_candidates(AggOp, fine_candidates)
149+
150+
@test isapprox(fine_candidates, Q * coarse_candidates)
151+
@test isapprox(Q * (Q' * fine_candidates), fine_candidates)
152+
end
153+
end
154+
function mask_candidates!(A,B)
155+
B[(diff(A.colptr) .== 0)] = 0
156+
end
157+
158+
function generate_fit_candidates_cases()
159+
cases = []
160+
161+
for T in (Float32, Float64)
162+
163+
# One candidate
164+
AggOp = SparseMatrixCSC(2, 5, collect(1:6),
165+
[1,1,1,2,2], ones(T,5))
166+
B = ones(T,5)
167+
push!(cases, (AggOp, B))
168+
169+
AggOp = SparseMatrixCSC(2, 5, collect(1:6),
170+
[2,2,1,1,1], ones(T,5))
171+
B = ones(T, 5)
172+
push!(cases, (AggOp, B))
173+
174+
AggOp = SparseMatrixCSC(3, 9, collect(1:10),
175+
[1,1,1,2,2,2,3,3,3], ones(T, 9))
176+
B = ones(T, 9)
177+
push!(cases, (AggOp, B))
178+
179+
#AggOp = SparseMatrixCSC(3, 9, collect(1:10),
180+
#[3,2,1,1,2,3,2,1,3], ones(T,9))
181+
#B = T.(collect(1:9))
182+
#push!(cases, (AggOp, B))
183+
end
184+
185+
cases
137186
end

0 commit comments

Comments
 (0)