Skip to content

Commit 6be3e32

Browse files
authored
Start adding support for matrix functions (#135)
1 parent 81003b1 commit 6be3e32

File tree

3 files changed

+186
-4
lines changed

3 files changed

+186
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.3"
4+
version = "0.7.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/linearalgebra.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,102 @@ function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix)
3232
end
3333
return tr_a
3434
end
35+
36+
# TODO: Define `SparseArraysBase.isdiag`, define as
37+
# `isdiag(blocks(a))`.
38+
function blockisdiag(a::AbstractArray)
39+
return all(eachblockstoredindex(a)) do I
40+
return allequal(Tuple(I))
41+
end
42+
end
43+
44+
const MATRIX_FUNCTIONS = [
45+
:exp,
46+
:cis,
47+
:log,
48+
:sqrt,
49+
:cbrt,
50+
:cos,
51+
:sin,
52+
:tan,
53+
:csc,
54+
:sec,
55+
:cot,
56+
:cosh,
57+
:sinh,
58+
:tanh,
59+
:csch,
60+
:sech,
61+
:coth,
62+
:acos,
63+
:asin,
64+
:atan,
65+
:acsc,
66+
:asec,
67+
:acot,
68+
:acosh,
69+
:asinh,
70+
:atanh,
71+
:acsch,
72+
:asech,
73+
:acoth,
74+
]
75+
76+
# Functions where the dense implementations in `LinearAlgebra` are
77+
# not type stable.
78+
const MATRIX_FUNCTIONS_UNSTABLE = [
79+
:log,
80+
:sqrt,
81+
:acos,
82+
:asin,
83+
:atan,
84+
:acsc,
85+
:asec,
86+
:acot,
87+
:acosh,
88+
:asinh,
89+
:atanh,
90+
:acsch,
91+
:asech,
92+
:acoth,
93+
]
94+
95+
function initialize_output_blocksparse(f::F, a::AbstractMatrix) where {F}
96+
B = Base.promote_op(f, blocktype(a))
97+
return similar(a, BlockType(B))
98+
end
99+
100+
function matrix_function_blocksparse(f::F, a::AbstractMatrix; kwargs...) where {F}
101+
blockisdiag(a) || throw(ArgumentError("`$f` only defined for block-diagonal matrices"))
102+
fa = initialize_output_blocksparse(f, a)
103+
for I in blockdiagindices(a)
104+
fa[I] = f(a[I]; kwargs...)
105+
end
106+
return fa
107+
end
108+
109+
for f in MATRIX_FUNCTIONS
110+
@eval begin
111+
function Base.$f(a::AnyAbstractBlockSparseMatrix)
112+
return matrix_function_blocksparse($f, a)
113+
end
114+
end
115+
end
116+
117+
for f in MATRIX_FUNCTIONS_UNSTABLE
118+
@eval begin
119+
function initialize_output_blocksparse(::typeof($f), a::AbstractMatrix)
120+
B = similartype(blocktype(a), complex(eltype(a)))
121+
return similar(a, BlockType(B))
122+
end
123+
end
124+
end
125+
126+
function LinearAlgebra.inv(a::AnyAbstractBlockSparseMatrix)
127+
return matrix_function_blocksparse(inv, a)
128+
end
129+
130+
using LinearAlgebra: LinearAlgebra, pinv
131+
function LinearAlgebra.pinv(a::AnyAbstractBlockSparseMatrix; kwargs...)
132+
return matrix_function_blocksparse(pinv, a; kwargs...)
133+
end

test/test_factorizations.jl

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays:
3-
BlockSparseArray, BlockDiagonal, blockstoredlength, eachblockstoredindex
3+
BlockSparseArrays,
4+
BlockDiagonal,
5+
BlockSparseArray,
6+
BlockSparseMatrix,
7+
blockstoredlength,
8+
eachblockstoredindex
9+
using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart, pinv
410
using MatrixAlgebraKit:
511
diagview,
612
eig_full,
@@ -22,10 +28,87 @@ using MatrixAlgebraKit:
2228
svd_trunc,
2329
truncrank,
2430
trunctol
25-
using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart
2631
using Random: Random
2732
using StableRNGs: StableRNG
28-
using Test: @inferred, @testset, @test
33+
using Test: @inferred, @test, @test_broken, @test_throws, @testset
34+
35+
@testset "Matrix functions (T=$elt)" for elt in (Float32, Float64, ComplexF64)
36+
rng = StableRNG(123)
37+
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
38+
a[Block(1, 1)] = randn(rng, elt, 2, 2)
39+
a[Block(2, 2)] = randn(rng, elt, 3, 3)
40+
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
41+
MATRIX_FUNCTIONS = [MATRIX_FUNCTIONS; [:inv, :pinv]]
42+
# Only works when real, also isn't defined in Julia 1.10.
43+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
44+
MATRIX_FUNCTIONS_LOW_ACCURACY = [:acoth]
45+
for f in setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_LOW_ACCURACY)
46+
@eval begin
47+
fa = $f($a)
48+
@test Matrix(fa) $f(Matrix($a)) rtol = (eps(real($elt)))
49+
@test fa isa BlockSparseMatrix
50+
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
51+
end
52+
end
53+
for f in MATRIX_FUNCTIONS_LOW_ACCURACY
54+
@eval begin
55+
fa = $f($a)
56+
if !Sys.isapple() && ($elt <: Real)
57+
# `acoth` appears to be broken on this matrix on Windows and Ubuntu
58+
# for real matrices.
59+
@test_broken Matrix(fa) $f(Matrix($a)) rtol = eps(real($elt))
60+
else
61+
@test Matrix(fa) $f(Matrix($a)) rtol = eps(real($elt))
62+
end
63+
@test fa isa BlockSparseMatrix
64+
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
65+
end
66+
end
67+
68+
# Catch case of off-diagonal blocks.
69+
rng = StableRNG(123)
70+
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
71+
a[Block(1, 1)] = randn(rng, elt, 2, 2)
72+
a[Block(1, 2)] = randn(rng, elt, 2, 3)
73+
for f in MATRIX_FUNCTIONS
74+
@eval begin
75+
@test_throws ArgumentError $f($a)
76+
end
77+
end
78+
79+
# Missing diagonal blocks.
80+
rng = StableRNG(123)
81+
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
82+
a[Block(2, 2)] = randn(rng, elt, 3, 3)
83+
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
84+
# These functions involve inverses so they break when there are zeros on the diagonal.
85+
MATRIX_FUNCTIONS_SINGULAR = [
86+
:log, :acsc, :asec, :acot, :acsch, :asech, :acoth, :csc, :cot, :csch, :coth
87+
]
88+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_SINGULAR)
89+
# Dense version is broken for some reason, investigate.
90+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
91+
for f in MATRIX_FUNCTIONS
92+
@eval begin
93+
fa = $f($a)
94+
@test Matrix(fa) $f(Matrix($a)) rtol = (eps(real($elt)))
95+
@test fa isa BlockSparseMatrix
96+
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
97+
end
98+
end
99+
100+
SINGULAR_EXCEPTION = if VERSION < v"1.11-"
101+
# A different exception is thrown in older versions of Julia.
102+
LinearAlgebra.LAPACKException
103+
else
104+
LinearAlgebra.SingularException
105+
end
106+
for f in setdiff(MATRIX_FUNCTIONS_SINGULAR, [:log])
107+
@eval begin
108+
@test_throws $SINGULAR_EXCEPTION $f($a)
109+
end
110+
end
111+
end
29112

30113
function test_svd(a, (U, S, Vᴴ); full=false)
31114
# Check that the SVD is correct

0 commit comments

Comments
 (0)