Skip to content

Commit 5336dfc

Browse files
committed
Start adding support for matrix functions
1 parent 81003b1 commit 5336dfc

File tree

3 files changed

+143
-4
lines changed

3 files changed

+143
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.3"
4+
version = "0.7.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/linearalgebra.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,70 @@ function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix)
3232
end
3333
return tr_a
3434
end
35+
36+
# TODO: Define `SparseArraysBase.isdiag`, define as
37+
# `isdiag(blocks(a))`.
38+
function blockisdiag(a::AbstractArray)
39+
return all(eachblockstoredindex(a)) do I
40+
return allequal(Tuple(I))
41+
end
42+
end
43+
44+
const MATRIX_FUNCTIONS = [
45+
:exp,
46+
:cis,
47+
:log,
48+
:sqrt,
49+
:cbrt,
50+
:cos,
51+
:sin,
52+
:tan,
53+
:csc,
54+
:sec,
55+
:cot,
56+
:cosh,
57+
:sinh,
58+
:tanh,
59+
:csch,
60+
:sech,
61+
:coth,
62+
:acos,
63+
:asin,
64+
:atan,
65+
:acsc,
66+
:asec,
67+
:acot,
68+
:acosh,
69+
:asinh,
70+
:atanh,
71+
:acsch,
72+
:asech,
73+
:acoth,
74+
]
75+
76+
function matrix_function_blocksparse(f::F, a::AbstractMatrix; kwargs...) where {F}
77+
blockisdiag(a) || throw(ArgumentError("`$f` only defined for block-diagonal matrices"))
78+
B = Base.promote_op(f, blocktype(a))
79+
fa = similar(a, BlockType(B))
80+
for I in blockdiagindices(a)
81+
fa[I] = f(a[I]; kwargs...)
82+
end
83+
return fa
84+
end
85+
86+
for f in MATRIX_FUNCTIONS
87+
@eval begin
88+
function Base.$f(a::AnyAbstractBlockSparseMatrix)
89+
return matrix_function_blocksparse($f, a)
90+
end
91+
end
92+
end
93+
94+
function LinearAlgebra.inv(a::AnyAbstractBlockSparseMatrix)
95+
return matrix_function_blocksparse(inv, a)
96+
end
97+
98+
using LinearAlgebra: LinearAlgebra, pinv
99+
function LinearAlgebra.pinv(a::AnyAbstractBlockSparseMatrix; kwargs...)
100+
return matrix_function_blocksparse(pinv, a; kwargs...)
101+
end

test/test_factorizations.jl

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays:
3-
BlockSparseArray, BlockDiagonal, blockstoredlength, eachblockstoredindex
3+
BlockSparseArrays,
4+
BlockDiagonal,
5+
BlockSparseArray,
6+
BlockSparseMatrix,
7+
blockstoredlength,
8+
eachblockstoredindex
9+
using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart, pinv
410
using MatrixAlgebraKit:
511
diagview,
612
eig_full,
@@ -22,10 +28,76 @@ using MatrixAlgebraKit:
2228
svd_trunc,
2329
truncrank,
2430
trunctol
25-
using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart
2631
using Random: Random
2732
using StableRNGs: StableRNG
28-
using Test: @inferred, @testset, @test
33+
using Test: @inferred, @test, @test_throws, @testset
34+
35+
# These functions involve inverses so break when there are zeros on the diagonal.
36+
MATRIX_FUNCTIONS_SINGULAR = [:csc, :cot, :csch, :coth]
37+
38+
# Broken because of type stability issues. Fix manually by forcing to be complex.
39+
MATRIX_FUNCTIONS_UNSTABLE = [
40+
:log,
41+
:sqrt,
42+
:acos,
43+
:asin,
44+
:atan,
45+
:acsc,
46+
:asec,
47+
:acot,
48+
:acosh,
49+
:asinh,
50+
:atanh,
51+
:acsch,
52+
:asech,
53+
:acoth,
54+
]
55+
56+
@testset "Matrix functions (eltype=$elt)" for elt in (Float32, Float64, ComplexF64)
57+
rng = StableRNG(123)
58+
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
59+
a[Block(1, 1)] = randn(rng, elt, 2, 2)
60+
a[Block(2, 2)] = randn(rng, elt, 3, 3)
61+
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
62+
MATRIX_FUNCTIONS = [MATRIX_FUNCTIONS; [:inv, :pinv]]
63+
# Only works when real, also isn't defined in Julia 1.10.
64+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
65+
# Broken because of type stability issues. Fix manually by forcing to be complex.
66+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_UNSTABLE)
67+
for f in MATRIX_FUNCTIONS
68+
@eval begin
69+
fa = $f($a)
70+
@test Matrix(fa) $f(Matrix($a))
71+
@test fa isa BlockSparseMatrix
72+
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
73+
end
74+
end
75+
76+
# Skip inverse functions when there are missing/zero diagonal blocks.
77+
rng = StableRNG(123)
78+
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
79+
a[Block(2, 2)] = randn(rng, elt, 3, 3)
80+
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
81+
# These functions involve inverses so break when there are zeros on the diagonal.
82+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_SINGULAR)
83+
# Broken because of type stability issues. Fix manually by forcing to be complex.
84+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_UNSTABLE)
85+
# Dense version is broken for some reason, investigate.
86+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
87+
for f in MATRIX_FUNCTIONS
88+
@eval begin
89+
fa = $f($a)
90+
@test Matrix(fa) $f(Matrix($a))
91+
@test fa isa BlockSparseMatrix
92+
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
93+
end
94+
end
95+
for f in MATRIX_FUNCTIONS_SINGULAR
96+
@eval begin
97+
@test_throws LinearAlgebra.SingularException $f($a)
98+
end
99+
end
100+
end
29101

30102
function test_svd(a, (U, S, Vᴴ); full=false)
31103
# Check that the SVD is correct

0 commit comments

Comments
 (0)