Skip to content

Commit c1f2d91

Browse files
committed
DArray: Fix norm of single-chunk upper/lower array
1 parent cd8ca95 commit c1f2d91

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

src/array/linalg.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
function LinearAlgebra.norm2(A::DArray{T,2}) where T
22
Ac = A.chunks
33
norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac]::Matrix{DTask}
4-
return sqrt(sum(map(norm->fetch(norm)::real(T), norms)))
4+
zeroRT = zero(real(T))
5+
return sqrt(sum(map(norm->fetch(norm)::real(T), norms); init=zeroRT))
56
end
67
function LinearAlgebra.norm2(A::UpperTriangular{T,<:DArray{T,2}}) where T
78
Ac = parent(A).chunks
@@ -12,7 +13,10 @@ function LinearAlgebra.norm2(A::UpperTriangular{T,<:DArray{T,2}}) where T
1213
upper_norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac_upper]
1314
Ac_diag = [Dagger.spawn(UpperTriangular, Ac[i,i]) for i in 1:size(Ac, 1)]
1415
diag_norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac_diag]
15-
return sqrt(sum(map(fetch, upper_norms)) + sum(map(fetch, diag_norms)))
16+
upper_norms_values = map(fetch, upper_norms)
17+
diag_norms_values = map(fetch, diag_norms)
18+
zeroRT = zero(real(T))
19+
return sqrt(sum(upper_norms_values; init=zeroRT) + sum(diag_norms_values; init=zeroRT))
1620
end
1721
function LinearAlgebra.norm2(A::LowerTriangular{T,<:DArray{T,2}}) where T
1822
Ac = parent(A).chunks
@@ -23,7 +27,10 @@ function LinearAlgebra.norm2(A::LowerTriangular{T,<:DArray{T,2}}) where T
2327
lower_norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac_lower]
2428
Ac_diag = [Dagger.spawn(LowerTriangular, Ac[i,i]) for i in 1:size(Ac, 1)]
2529
diag_norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac_diag]
26-
return sqrt(sum(map(fetch, lower_norms)) + sum(map(fetch, diag_norms)))
30+
lower_norms_values = map(fetch, lower_norms)
31+
diag_norms_values = map(fetch, diag_norms)
32+
zeroRT = zero(real(T))
33+
return sqrt(sum(lower_norms_values; init=zeroRT) + sum(diag_norms_values; init=zeroRT))
2734
end
2835

2936
is_cross_symmetric(A1, A2) = A1 == A2'

test/array/linalg/core.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@testset "isapprox" begin
2+
A = rand(16, 16)
3+
4+
U1 = UpperTriangular(DArray(A, Blocks(16, 16)))
5+
U2 = UpperTriangular(DArray(A, Blocks(16, 16)))
6+
@test isapprox(U1, U2)
7+
8+
L1 = LowerTriangular(DArray(A, Blocks(16, 16)))
9+
L2 = LowerTriangular(DArray(A, Blocks(16, 16)))
10+
@test isapprox(L1, L2)
11+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ tests = [
2727
("Array - Core", "array/core.jl"),
2828
("Array - Copyto", "array/copyto.jl"),
2929
("Array - MapReduce", "array/mapreduce.jl"),
30+
("Array - LinearAlgebra - Core", "array/linalg/core.jl"),
3031
("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"),
3132
("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"),
3233
("Array - LinearAlgebra - LU", "array/linalg/lu.jl"),

0 commit comments

Comments
 (0)