diff --git a/Project.toml b/Project.toml index de42c02e..97958cff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.3" +version = "0.7.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractblocksparsearray/linearalgebra.jl b/src/abstractblocksparsearray/linearalgebra.jl index 8121477a..70702f0c 100644 --- a/src/abstractblocksparsearray/linearalgebra.jl +++ b/src/abstractblocksparsearray/linearalgebra.jl @@ -32,3 +32,102 @@ function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix) end return tr_a end + +# TODO: Define `SparseArraysBase.isdiag`, define as +# `isdiag(blocks(a))`. +function blockisdiag(a::AbstractArray) + return all(eachblockstoredindex(a)) do I + return allequal(Tuple(I)) + end +end + +const MATRIX_FUNCTIONS = [ + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + +# Functions where the dense implementations in `LinearAlgebra` are +# not type stable. +const MATRIX_FUNCTIONS_UNSTABLE = [ + :log, + :sqrt, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + +function initialize_output_blocksparse(f::F, a::AbstractMatrix) where {F} + B = Base.promote_op(f, blocktype(a)) + return similar(a, BlockType(B)) +end + +function matrix_function_blocksparse(f::F, a::AbstractMatrix; kwargs...) where {F} + blockisdiag(a) || throw(ArgumentError("`$f` only defined for block-diagonal matrices")) + fa = initialize_output_blocksparse(f, a) + for I in blockdiagindices(a) + fa[I] = f(a[I]; kwargs...) + end + return fa +end + +for f in MATRIX_FUNCTIONS + @eval begin + function Base.$f(a::AnyAbstractBlockSparseMatrix) + return matrix_function_blocksparse($f, a) + end + end +end + +for f in MATRIX_FUNCTIONS_UNSTABLE + @eval begin + function initialize_output_blocksparse(::typeof($f), a::AbstractMatrix) + B = similartype(blocktype(a), complex(eltype(a))) + return similar(a, BlockType(B)) + end + end +end + +function LinearAlgebra.inv(a::AnyAbstractBlockSparseMatrix) + return matrix_function_blocksparse(inv, a) +end + +using LinearAlgebra: LinearAlgebra, pinv +function LinearAlgebra.pinv(a::AnyAbstractBlockSparseMatrix; kwargs...) + return matrix_function_blocksparse(pinv, a; kwargs...) +end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 222ae1c1..b934f7e9 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,6 +1,12 @@ using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar using BlockSparseArrays: - BlockSparseArray, BlockDiagonal, blockstoredlength, eachblockstoredindex + BlockSparseArrays, + BlockDiagonal, + BlockSparseArray, + BlockSparseMatrix, + blockstoredlength, + eachblockstoredindex +using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart, pinv using MatrixAlgebraKit: diagview, eig_full, @@ -22,10 +28,87 @@ using MatrixAlgebraKit: svd_trunc, truncrank, trunctol -using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart using Random: Random using StableRNGs: StableRNG -using Test: @inferred, @testset, @test +using Test: @inferred, @test, @test_broken, @test_throws, @testset + +@testset "Matrix functions (T=$elt)" for elt in (Float32, Float64, ComplexF64) + rng = StableRNG(123) + a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3]) + a[Block(1, 1)] = randn(rng, elt, 2, 2) + a[Block(2, 2)] = randn(rng, elt, 3, 3) + MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS + MATRIX_FUNCTIONS = [MATRIX_FUNCTIONS; [:inv, :pinv]] + # Only works when real, also isn't defined in Julia 1.10. + MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt]) + MATRIX_FUNCTIONS_LOW_ACCURACY = [:acoth] + for f in setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_LOW_ACCURACY) + @eval begin + fa = $f($a) + @test Matrix(fa) ≈ $f(Matrix($a)) rtol = √(eps(real($elt))) + @test fa isa BlockSparseMatrix + @test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)]) + end + end + for f in MATRIX_FUNCTIONS_LOW_ACCURACY + @eval begin + fa = $f($a) + if !Sys.isapple() && ($elt <: Real) + # `acoth` appears to be broken on this matrix on Windows and Ubuntu + # for real matrices. + @test_broken Matrix(fa) ≈ $f(Matrix($a)) rtol = √eps(real($elt)) + else + @test Matrix(fa) ≈ $f(Matrix($a)) rtol = √eps(real($elt)) + end + @test fa isa BlockSparseMatrix + @test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)]) + end + end + + # Catch case of off-diagonal blocks. + rng = StableRNG(123) + a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3]) + a[Block(1, 1)] = randn(rng, elt, 2, 2) + a[Block(1, 2)] = randn(rng, elt, 2, 3) + for f in MATRIX_FUNCTIONS + @eval begin + @test_throws ArgumentError $f($a) + end + end + + # Missing diagonal blocks. + rng = StableRNG(123) + a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3]) + a[Block(2, 2)] = randn(rng, elt, 3, 3) + MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS + # These functions involve inverses so they break when there are zeros on the diagonal. + MATRIX_FUNCTIONS_SINGULAR = [ + :log, :acsc, :asec, :acot, :acsch, :asech, :acoth, :csc, :cot, :csch, :coth + ] + MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_SINGULAR) + # Dense version is broken for some reason, investigate. + MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt]) + for f in MATRIX_FUNCTIONS + @eval begin + fa = $f($a) + @test Matrix(fa) ≈ $f(Matrix($a)) rtol = √(eps(real($elt))) + @test fa isa BlockSparseMatrix + @test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)]) + end + end + + SINGULAR_EXCEPTION = if VERSION < v"1.11-" + # A different exception is thrown in older versions of Julia. + LinearAlgebra.LAPACKException + else + LinearAlgebra.SingularException + end + for f in setdiff(MATRIX_FUNCTIONS_SINGULAR, [:log]) + @eval begin + @test_throws $SINGULAR_EXCEPTION $f($a) + end + end +end function test_svd(a, (U, S, Vᴴ); full=false) # Check that the SVD is correct