Skip to content

Commit 6be24a6

Browse files
committed
Fix more matrix functions
1 parent 5336dfc commit 6be24a6

File tree

2 files changed

+60
-32
lines changed

2 files changed

+60
-32
lines changed

src/abstractblocksparsearray/linearalgebra.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,33 @@ const MATRIX_FUNCTIONS = [
7373
:acoth,
7474
]
7575

76+
# Functions where the dense implementations in `LinearAlgebra` are
77+
# not type stable.
78+
const MATRIX_FUNCTIONS_UNSTABLE = [
79+
:log,
80+
:sqrt,
81+
:acos,
82+
:asin,
83+
:atan,
84+
:acsc,
85+
:asec,
86+
:acot,
87+
:acosh,
88+
:asinh,
89+
:atanh,
90+
:acsch,
91+
:asech,
92+
:acoth,
93+
]
94+
95+
function initialize_output_blocksparse(f::F, a::AbstractMatrix) where {F}
96+
B = Base.promote_op(f, blocktype(a))
97+
return similar(a, BlockType(B))
98+
end
99+
76100
function matrix_function_blocksparse(f::F, a::AbstractMatrix; kwargs...) where {F}
77101
blockisdiag(a) || throw(ArgumentError("`$f` only defined for block-diagonal matrices"))
78-
B = Base.promote_op(f, blocktype(a))
79-
fa = similar(a, BlockType(B))
102+
fa = initialize_output_blocksparse(f, a)
80103
for I in blockdiagindices(a)
81104
fa[I] = f(a[I]; kwargs...)
82105
end
@@ -91,6 +114,15 @@ for f in MATRIX_FUNCTIONS
91114
end
92115
end
93116

117+
for f in MATRIX_FUNCTIONS_UNSTABLE
118+
@eval begin
119+
function initialize_output_blocksparse(::typeof($f), a::AbstractMatrix)
120+
B = similartype(blocktype(a), complex(eltype(a)))
121+
return similar(a, BlockType(B))
122+
end
123+
end
124+
end
125+
94126
function LinearAlgebra.inv(a::AnyAbstractBlockSparseMatrix)
95127
return matrix_function_blocksparse(inv, a)
96128
end

test/test_factorizations.jl

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,7 @@ using Random: Random
3232
using StableRNGs: StableRNG
3333
using Test: @inferred, @test, @test_throws, @testset
3434

35-
# These functions involve inverses so break when there are zeros on the diagonal.
36-
MATRIX_FUNCTIONS_SINGULAR = [:csc, :cot, :csch, :coth]
37-
38-
# Broken because of type stability issues. Fix manually by forcing to be complex.
39-
MATRIX_FUNCTIONS_UNSTABLE = [
40-
:log,
41-
:sqrt,
42-
:acos,
43-
:asin,
44-
:atan,
45-
:acsc,
46-
:asec,
47-
:acot,
48-
:acosh,
49-
:asinh,
50-
:atanh,
51-
:acsch,
52-
:asech,
53-
:acoth,
54-
]
55-
56-
@testset "Matrix functions (eltype=$elt)" for elt in (Float32, Float64, ComplexF64)
35+
@testset "Matrix functions (T=$elt)" for elt in (Float32, Float64, ComplexF64)
5736
rng = StableRNG(123)
5837
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
5938
a[Block(1, 1)] = randn(rng, elt, 2, 2)
@@ -62,8 +41,6 @@ MATRIX_FUNCTIONS_UNSTABLE = [
6241
MATRIX_FUNCTIONS = [MATRIX_FUNCTIONS; [:inv, :pinv]]
6342
# Only works when real, also isn't defined in Julia 1.10.
6443
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
65-
# Broken because of type stability issues. Fix manually by forcing to be complex.
66-
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_UNSTABLE)
6744
for f in MATRIX_FUNCTIONS
6845
@eval begin
6946
fa = $f($a)
@@ -73,15 +50,27 @@ MATRIX_FUNCTIONS_UNSTABLE = [
7350
end
7451
end
7552

76-
# Skip inverse functions when there are missing/zero diagonal blocks.
53+
# Catch case of off-diagonal blocks.
54+
rng = StableRNG(123)
55+
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
56+
a[Block(1, 1)] = randn(rng, elt, 2, 2)
57+
a[Block(1, 2)] = randn(rng, elt, 2, 3)
58+
for f in MATRIX_FUNCTIONS
59+
@eval begin
60+
@test_throws ArgumentError $f($a)
61+
end
62+
end
63+
64+
# Missing diagonal blocks.
7765
rng = StableRNG(123)
7866
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
7967
a[Block(2, 2)] = randn(rng, elt, 3, 3)
8068
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
81-
# These functions involve inverses so break when there are zeros on the diagonal.
69+
# These functions involve inverses so they break when there are zeros on the diagonal.
70+
MATRIX_FUNCTIONS_SINGULAR = [
71+
:log, :acsc, :asec, :acot, :acsch, :asech, :acoth, :csc, :cot, :csch, :coth
72+
]
8273
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_SINGULAR)
83-
# Broken because of type stability issues. Fix manually by forcing to be complex.
84-
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_UNSTABLE)
8574
# Dense version is broken for some reason, investigate.
8675
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
8776
for f in MATRIX_FUNCTIONS
@@ -92,9 +81,16 @@ MATRIX_FUNCTIONS_UNSTABLE = [
9281
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
9382
end
9483
end
95-
for f in MATRIX_FUNCTIONS_SINGULAR
84+
85+
SINGULAR_EXCEPTION = if VERSION < v"1.11-"
86+
# A different exception is thrown in older versions of Julia.
87+
LinearAlgebra.LAPACKException
88+
else
89+
LinearAlgebra.SingularException
90+
end
91+
for f in setdiff(MATRIX_FUNCTIONS_SINGULAR, [:log])
9692
@eval begin
97-
@test_throws LinearAlgebra.SingularException $f($a)
93+
@test_throws $SINGULAR_EXCEPTION $f($a)
9894
end
9995
end
10096
end

0 commit comments

Comments
 (0)