diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 4c49a86..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options -style = "blue" -indent = 2 diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 456fa05..0614de9 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -2,7 +2,7 @@ name: "CompatHelper" on: schedule: - - cron: 0 0 * * * + - cron: '0 0 * * *' workflow_dispatch: permissions: contents: write diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index 3f78afc..1525861 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,11 +1,14 @@ name: "Format Check" on: - push: - branches: - - 'main' - tags: '*' - pull_request: + pull_request_target: + paths: ['**/*.jl'] + types: [opened, synchronize, reopened, ready_for_review] + +permissions: + contents: read + actions: write + pull-requests: write jobs: format-check: diff --git a/.gitignore b/.gitignore index 10593a9..7085ca8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,10 @@ .vscode/ Manifest.toml benchmark/*.json +dev/ +docs/LocalPreferences.toml docs/Manifest.toml docs/build/ docs/src/index.md +examples/LocalPreferences.toml +test/LocalPreferences.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88bc8b4..3fc4743 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ ci: - skip: [julia-formatter] + skip: [runic] repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -11,7 +11,7 @@ repos: - id: end-of-file-fixer exclude_types: [markdown] # incompatible with Literate.jl -- repo: "https://github.com/domluna/JuliaFormatter.jl" - rev: v2.1.6 +- repo: https://github.com/fredrikekre/runic-pre-commit + rev: v2.0.1 hooks: - - id: "julia-formatter" + - id: runic diff --git a/Project.toml b/Project.toml index f016c0d..6b4150a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.21" +version = "0.3.22" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/README.md b/README.md index 50b6f7e..6867869 100644 --- a/README.md +++ b/README.md @@ -45,67 +45,67 @@ julia> Pkg.add("DiagonalArrays") ````julia using DiagonalArrays: - DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, diaglength, isdiagindex + DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, diaglength, isdiagindex using Test: @test function main() - d = DiagonalMatrix([1.0, 2.0, 3.0]) - @test eltype(d) == Float64 - @test diaglength(d) == 3 - @test size(d) == (3, 3) - @test d[1, 1] == 1 - @test d[2, 2] == 2 - @test d[3, 3] == 3 - @test d[1, 2] == 0 - - d = DiagonalArray([1.0, 2.0, 3.0], 3, 4, 5) - @test eltype(d) == Float64 - @test diaglength(d) == 3 - @test d[1, 1, 1] == 1 - @test d[2, 2, 2] == 2 - @test d[3, 3, 3] == 3 - @test d[1, 2, 1] == 0 - - d[2, 2, 2] = 22 - @test d[2, 2, 2] == 22 - - d_r = reshape(d, 3, 20) - @test size(d_r) == (3, 20) - @test all(I -> d_r[I] == d[I], LinearIndices(d)) - - @test length(d[DiagIndices(:)]) == 3 - @test Array(d) == d - @test d[DiagIndex(2)] == d[2, 2, 2] - - d[DiagIndex(2)] = 222 - @test d[2, 2, 2] == 222 - - a = randn(3, 4, 5) - new_diag = randn(3) - a[DiagIndices(:)] = new_diag - d[DiagIndices(:)] = a[DiagIndices(:)] - - @test a[DiagIndices(:)] == new_diag - @test d[DiagIndices(:)] == new_diag - - permuted_d = permutedims(d, (3, 2, 1)) - @test permuted_d isa DiagonalArray - @test permuted_d[DiagIndices(:)] == d[DiagIndices(:)] - @test size(d) == (3, 4, 5) - @test size(permuted_d) == (5, 4, 3) - for I in eachindex(d) - if !isdiagindex(d, I) - @test iszero(d[I]) - else - @test !iszero(d[I]) + d = DiagonalMatrix([1.0, 2.0, 3.0]) + @test eltype(d) == Float64 + @test diaglength(d) == 3 + @test size(d) == (3, 3) + @test d[1, 1] == 1 + @test d[2, 2] == 2 + @test d[3, 3] == 3 + @test d[1, 2] == 0 + + d = DiagonalArray([1.0, 2.0, 3.0], 3, 4, 5) + @test eltype(d) == Float64 + @test diaglength(d) == 3 + @test d[1, 1, 1] == 1 + @test d[2, 2, 2] == 2 + @test d[3, 3, 3] == 3 + @test d[1, 2, 1] == 0 + + d[2, 2, 2] = 22 + @test d[2, 2, 2] == 22 + + d_r = reshape(d, 3, 20) + @test size(d_r) == (3, 20) + @test all(I -> d_r[I] == d[I], LinearIndices(d)) + + @test length(d[DiagIndices(:)]) == 3 + @test Array(d) == d + @test d[DiagIndex(2)] == d[2, 2, 2] + + d[DiagIndex(2)] = 222 + @test d[2, 2, 2] == 222 + + a = randn(3, 4, 5) + new_diag = randn(3) + a[DiagIndices(:)] = new_diag + d[DiagIndices(:)] = a[DiagIndices(:)] + + @test a[DiagIndices(:)] == new_diag + @test d[DiagIndices(:)] == new_diag + + permuted_d = permutedims(d, (3, 2, 1)) + @test permuted_d isa DiagonalArray + @test permuted_d[DiagIndices(:)] == d[DiagIndices(:)] + @test size(d) == (3, 4, 5) + @test size(permuted_d) == (5, 4, 3) + for I in eachindex(d) + if !isdiagindex(d, I) + @test iszero(d[I]) + else + @test !iszero(d[I]) + end end - end - mapped_d = map(x -> 2x, d) - @test mapped_d isa DiagonalArray - @test mapped_d == map(x -> 2x, Array(d)) + mapped_d = map(x -> 2x, d) + @test mapped_d isa DiagonalArray + @test mapped_d == map(x -> 2x, Array(d)) - return nothing + return nothing end main() diff --git a/docs/make.jl b/docs/make.jl index 52dc9d1..da641ab 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,22 +1,22 @@ using DiagonalArrays: DiagonalArrays using Documenter: Documenter, DocMeta, deploydocs, makedocs -DocMeta.setdocmeta!(DiagonalArrays, :DocTestSetup, :(using DiagonalArrays); recursive=true) +DocMeta.setdocmeta!(DiagonalArrays, :DocTestSetup, :(using DiagonalArrays); recursive = true) include("make_index.jl") makedocs(; - modules=[DiagonalArrays], - authors="ITensor developers and contributors", - sitename="DiagonalArrays.jl", - format=Documenter.HTML(; - canonical="https://itensor.github.io/DiagonalArrays.jl", - edit_link="main", - assets=["assets/favicon.ico", "assets/extras.css"], - ), - pages=["Home" => "index.md", "Reference" => "reference.md"], + modules = [DiagonalArrays], + authors = "ITensor developers and contributors", + sitename = "DiagonalArrays.jl", + format = Documenter.HTML(; + canonical = "https://itensor.github.io/DiagonalArrays.jl", + edit_link = "main", + assets = ["assets/favicon.ico", "assets/extras.css"], + ), + pages = ["Home" => "index.md", "Reference" => "reference.md"], ) deploydocs(; - repo="github.com/ITensor/DiagonalArrays.jl", devbranch="main", push_preview=true + repo = "github.com/ITensor/DiagonalArrays.jl", devbranch = "main", push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index b19798f..2b7321b 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -2,20 +2,20 @@ using Literate: Literate using DiagonalArrays: DiagonalArrays function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ ```@raw html Flatiron Center for Computational Quantum Physics logo. Flatiron Center for Computational Quantum Physics logo. ``` """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(DiagonalArrays), "examples", "README.jl"), - joinpath(pkgdir(DiagonalArrays), "docs", "src"); - flavor=Literate.DocumenterFlavor(), - name="index", - postprocess=ccq_logo, + joinpath(pkgdir(DiagonalArrays), "examples", "README.jl"), + joinpath(pkgdir(DiagonalArrays), "docs", "src"); + flavor = Literate.DocumenterFlavor(), + name = "index", + postprocess = ccq_logo, ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index 0c306d0..8309caa 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -2,20 +2,20 @@ using Literate: Literate using DiagonalArrays: DiagonalArrays function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ Flatiron Center for Computational Quantum Physics logo. """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(DiagonalArrays), "examples", "README.jl"), - joinpath(pkgdir(DiagonalArrays)); - flavor=Literate.CommonMarkFlavor(), - name="README", - postprocess=ccq_logo, + joinpath(pkgdir(DiagonalArrays), "examples", "README.jl"), + joinpath(pkgdir(DiagonalArrays)); + flavor = Literate.CommonMarkFlavor(), + name = "README", + postprocess = ccq_logo, ) diff --git a/examples/README.jl b/examples/README.jl index 1678f75..bb6efdc 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -1,5 +1,5 @@ # # DiagonalArrays.jl -# +# # [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://itensor.github.io/DiagonalArrays.jl/stable/) # [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://itensor.github.io/DiagonalArrays.jl/dev/) # [![Build Status](https://github.com/ITensor/DiagonalArrays.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/ITensor/DiagonalArrays.jl/actions/workflows/Tests.yml?query=branch%3Amain) @@ -46,67 +46,67 @@ julia> Pkg.add("DiagonalArrays") # ## Examples using DiagonalArrays: - DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, diaglength, isdiagindex + DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, diaglength, isdiagindex using Test: @test function main() - d = DiagonalMatrix([1.0, 2.0, 3.0]) - @test eltype(d) == Float64 - @test diaglength(d) == 3 - @test size(d) == (3, 3) - @test d[1, 1] == 1 - @test d[2, 2] == 2 - @test d[3, 3] == 3 - @test d[1, 2] == 0 - - d = DiagonalArray([1.0, 2.0, 3.0], 3, 4, 5) - @test eltype(d) == Float64 - @test diaglength(d) == 3 - @test d[1, 1, 1] == 1 - @test d[2, 2, 2] == 2 - @test d[3, 3, 3] == 3 - @test d[1, 2, 1] == 0 - - d[2, 2, 2] = 22 - @test d[2, 2, 2] == 22 - - d_r = reshape(d, 3, 20) - @test size(d_r) == (3, 20) - @test all(I -> d_r[I] == d[I], LinearIndices(d)) - - @test length(d[DiagIndices(:)]) == 3 - @test Array(d) == d - @test d[DiagIndex(2)] == d[2, 2, 2] - - d[DiagIndex(2)] = 222 - @test d[2, 2, 2] == 222 - - a = randn(3, 4, 5) - new_diag = randn(3) - a[DiagIndices(:)] = new_diag - d[DiagIndices(:)] = a[DiagIndices(:)] - - @test a[DiagIndices(:)] == new_diag - @test d[DiagIndices(:)] == new_diag - - permuted_d = permutedims(d, (3, 2, 1)) - @test permuted_d isa DiagonalArray - @test permuted_d[DiagIndices(:)] == d[DiagIndices(:)] - @test size(d) == (3, 4, 5) - @test size(permuted_d) == (5, 4, 3) - for I in eachindex(d) - if !isdiagindex(d, I) - @test iszero(d[I]) - else - @test !iszero(d[I]) + d = DiagonalMatrix([1.0, 2.0, 3.0]) + @test eltype(d) == Float64 + @test diaglength(d) == 3 + @test size(d) == (3, 3) + @test d[1, 1] == 1 + @test d[2, 2] == 2 + @test d[3, 3] == 3 + @test d[1, 2] == 0 + + d = DiagonalArray([1.0, 2.0, 3.0], 3, 4, 5) + @test eltype(d) == Float64 + @test diaglength(d) == 3 + @test d[1, 1, 1] == 1 + @test d[2, 2, 2] == 2 + @test d[3, 3, 3] == 3 + @test d[1, 2, 1] == 0 + + d[2, 2, 2] = 22 + @test d[2, 2, 2] == 22 + + d_r = reshape(d, 3, 20) + @test size(d_r) == (3, 20) + @test all(I -> d_r[I] == d[I], LinearIndices(d)) + + @test length(d[DiagIndices(:)]) == 3 + @test Array(d) == d + @test d[DiagIndex(2)] == d[2, 2, 2] + + d[DiagIndex(2)] = 222 + @test d[2, 2, 2] == 222 + + a = randn(3, 4, 5) + new_diag = randn(3) + a[DiagIndices(:)] = new_diag + d[DiagIndices(:)] = a[DiagIndices(:)] + + @test a[DiagIndices(:)] == new_diag + @test d[DiagIndices(:)] == new_diag + + permuted_d = permutedims(d, (3, 2, 1)) + @test permuted_d isa DiagonalArray + @test permuted_d[DiagIndices(:)] == d[DiagIndices(:)] + @test size(d) == (3, 4, 5) + @test size(permuted_d) == (5, 4, 3) + for I in eachindex(d) + if !isdiagindex(d, I) + @test iszero(d[I]) + else + @test !iszero(d[I]) + end end - end - mapped_d = map(x -> 2x, d) - @test mapped_d isa DiagonalArray - @test mapped_d == map(x -> 2x, Array(d)) + mapped_d = map(x -> 2x, d) + @test mapped_d isa DiagonalArray + @test mapped_d == map(x -> 2x, Array(d)) - return nothing + return nothing end main() diff --git a/ext/DiagonalArraysMatrixAlgebraKitExt/DiagonalArraysMatrixAlgebraKitExt.jl b/ext/DiagonalArraysMatrixAlgebraKitExt/DiagonalArraysMatrixAlgebraKitExt.jl index 3f94c39..e0bbed4 100644 --- a/ext/DiagonalArraysMatrixAlgebraKitExt/DiagonalArraysMatrixAlgebraKitExt.jl +++ b/ext/DiagonalArraysMatrixAlgebraKitExt/DiagonalArraysMatrixAlgebraKitExt.jl @@ -1,313 +1,313 @@ module DiagonalArraysMatrixAlgebraKitExt using DiagonalArrays: - AbstractDiagonalMatrix, - DeltaMatrix, - DiagonalMatrix, - ScaledDeltaMatrix, - δ, - diagview, - dual, - issquare + AbstractDiagonalMatrix, + DeltaMatrix, + DiagonalMatrix, + ScaledDeltaMatrix, + δ, + diagview, + dual, + issquare using LinearAlgebra: LinearAlgebra, isdiag, ishermitian using MatrixAlgebraKit: - MatrixAlgebraKit, - AbstractAlgorithm, - check_input, - default_qr_algorithm, - eig_full, - eig_full!, - eig_vals, - eig_vals!, - eigh_full, - eigh_full!, - eigh_vals, - eigh_vals!, - left_null, - left_null!, - left_orth, - left_orth!, - left_polar, - left_polar!, - lq_compact, - lq_compact!, - lq_full, - lq_full!, - qr_compact, - qr_compact!, - qr_full, - qr_full!, - right_null, - right_null!, - right_orth, - right_orth!, - right_polar, - right_polar!, - svd_compact, - svd_compact!, - svd_full, - svd_full!, - svd_vals, - svd_vals! + MatrixAlgebraKit, + AbstractAlgorithm, + check_input, + default_qr_algorithm, + eig_full, + eig_full!, + eig_vals, + eig_vals!, + eigh_full, + eigh_full!, + eigh_vals, + eigh_vals!, + left_null, + left_null!, + left_orth, + left_orth!, + left_polar, + left_polar!, + lq_compact, + lq_compact!, + lq_full, + lq_full!, + qr_compact, + qr_compact!, + qr_full, + qr_full!, + right_null, + right_null!, + right_orth, + right_orth!, + right_polar, + right_polar!, + svd_compact, + svd_compact!, + svd_full, + svd_full!, + svd_vals, + svd_vals! abstract type AbstractDiagonalAlgorithm <: AbstractAlgorithm end -struct DeltaAlgorithm{KWargs<:NamedTuple} <: AbstractDiagonalAlgorithm - kwargs::KWargs +struct DeltaAlgorithm{KWargs <: NamedTuple} <: AbstractDiagonalAlgorithm + kwargs::KWargs end DeltaAlgorithm(; kwargs...) = DeltaAlgorithm((; kwargs...)) -struct ScaledDeltaAlgorithm{KWargs<:NamedTuple} <: AbstractDiagonalAlgorithm - kwargs::KWargs +struct ScaledDeltaAlgorithm{KWargs <: NamedTuple} <: AbstractDiagonalAlgorithm + kwargs::KWargs end ScaledDeltaAlgorithm(; kwargs...) = ScaledDeltaAlgorithm((; kwargs...)) for f in [ - :eig_full, - :eig_vals, - :eigh_full, - :eigh_vals, - :qr_compact, - :qr_full, - :left_null, - :left_orth, - :left_polar, - :lq_compact, - :lq_full, - :right_null, - :right_orth, - :right_polar, - :svd_compact, - :svd_full, - :svd_vals, -] - @eval begin - MatrixAlgebraKit.copy_input(::typeof($f), a::AbstractDiagonalMatrix) = copy(a) - end + :eig_full, + :eig_vals, + :eigh_full, + :eigh_vals, + :qr_compact, + :qr_full, + :left_null, + :left_orth, + :left_polar, + :lq_compact, + :lq_full, + :right_null, + :right_orth, + :right_polar, + :svd_compact, + :svd_full, + :svd_vals, + ] + @eval begin + MatrixAlgebraKit.copy_input(::typeof($f), a::AbstractDiagonalMatrix) = copy(a) + end end for f in [ - :default_eig_algorithm, - :default_eigh_algorithm, - :default_lq_algorithm, - :default_qr_algorithm, - :default_polar_algorithm, - :default_svd_algorithm, -] - @eval begin - function MatrixAlgebraKit.$f(::Type{<:DeltaMatrix}; kwargs...) - return DeltaAlgorithm(; kwargs...) - end - function MatrixAlgebraKit.$f(::Type{<:ScaledDeltaMatrix}; kwargs...) - return ScaledDeltaAlgorithm(; kwargs...) + :default_eig_algorithm, + :default_eigh_algorithm, + :default_lq_algorithm, + :default_qr_algorithm, + :default_polar_algorithm, + :default_svd_algorithm, + ] + @eval begin + function MatrixAlgebraKit.$f(::Type{<:DeltaMatrix}; kwargs...) + return DeltaAlgorithm(; kwargs...) + end + function MatrixAlgebraKit.$f(::Type{<:ScaledDeltaMatrix}; kwargs...) + return ScaledDeltaAlgorithm(; kwargs...) + end end - end end for f in [ - :eig_full!, - :eig_vals!, - :eigh_full!, - :eigh_vals!, - :left_null!, - :left_orth!, - :left_polar!, - :lq_compact!, - :lq_full!, - :qr_compact!, - :qr_full!, - :right_null!, - :right_orth!, - :right_polar!, - :svd_compact!, - :svd_full!, - :svd_vals!, -] - for Alg in [:ScaledDeltaAlgorithm, :DeltaAlgorithm] - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::AbstractMatrix, alg::$Alg - ) - return nothing - end + :eig_full!, + :eig_vals!, + :eigh_full!, + :eigh_vals!, + :left_null!, + :left_orth!, + :left_polar!, + :lq_compact!, + :lq_full!, + :qr_compact!, + :qr_full!, + :right_null!, + :right_orth!, + :right_polar!, + :svd_compact!, + :svd_full!, + :svd_vals!, + ] + for Alg in [:ScaledDeltaAlgorithm, :DeltaAlgorithm] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::AbstractMatrix, alg::$Alg + ) + return nothing + end + end end - end end for f in [ - :left_null!, - :left_orth!, - :left_polar!, - :lq_compact!, - :lq_full!, - :qr_compact!, - :qr_full!, - :right_null!, - :right_orth!, - :right_polar!, - :svd_compact!, - :svd_full!, - :svd_vals!, -] - @eval begin - function MatrixAlgebraKit.check_input( - ::typeof($f), a::AbstractMatrix, F, alg::DeltaAlgorithm - ) - @assert size(a, 1) == size(a, 2) - @assert isdiag(a) - @assert all(isone, diagview(a)) - return nothing - end - function MatrixAlgebraKit.check_input( - ::typeof($f), a::AbstractMatrix, F, alg::ScaledDeltaAlgorithm - ) - @assert size(a, 1) == size(a, 2) - @assert isdiag(a) - @assert allequal(diagview(a)) - return nothing + :left_null!, + :left_orth!, + :left_polar!, + :lq_compact!, + :lq_full!, + :qr_compact!, + :qr_full!, + :right_null!, + :right_orth!, + :right_polar!, + :svd_compact!, + :svd_full!, + :svd_vals!, + ] + @eval begin + function MatrixAlgebraKit.check_input( + ::typeof($f), a::AbstractMatrix, F, alg::DeltaAlgorithm + ) + @assert size(a, 1) == size(a, 2) + @assert isdiag(a) + @assert all(isone, diagview(a)) + return nothing + end + function MatrixAlgebraKit.check_input( + ::typeof($f), a::AbstractMatrix, F, alg::ScaledDeltaAlgorithm + ) + @assert size(a, 1) == size(a, 2) + @assert isdiag(a) + @assert allequal(diagview(a)) + return nothing + end end - end end for f in [:eig_full!, :eig_vals!, :eigh_full!, :eigh_vals!] - @eval begin - function MatrixAlgebraKit.check_input( - ::typeof($f), a::AbstractMatrix, F, alg::DeltaAlgorithm - ) - @assert issquare(a) - @assert isdiag(a) - @assert all(isone, diagview(a)) - return nothing - end - function MatrixAlgebraKit.check_input( - ::typeof($f), a::AbstractMatrix, F, alg::ScaledDeltaAlgorithm - ) - @assert issquare(a) - @assert isdiag(a) - @assert allequal(diagview(a)) - return nothing + @eval begin + function MatrixAlgebraKit.check_input( + ::typeof($f), a::AbstractMatrix, F, alg::DeltaAlgorithm + ) + @assert issquare(a) + @assert isdiag(a) + @assert all(isone, diagview(a)) + return nothing + end + function MatrixAlgebraKit.check_input( + ::typeof($f), a::AbstractMatrix, F, alg::ScaledDeltaAlgorithm + ) + @assert issquare(a) + @assert isdiag(a) + @assert allequal(diagview(a)) + return nothing + end end - end end # eig for Alg in [:DeltaAlgorithm, :ScaledDeltaAlgorithm] - @eval begin - function MatrixAlgebraKit.eig_full!(a, F, alg::$Alg) - check_input(eig_full!, a, F, alg) - d = complex(a) - v = δ(complex(eltype(a)), axes(a)) - return (d, v) - end - function MatrixAlgebraKit.eigh_full!(a, F, alg::$Alg) - check_input(eigh_full!, a, F, alg) - ishermitian(a) || throw(ArgumentError("Matrix must be Hermitian")) - d = real(a) - v = δ(eltype(a), axes(a)) - return (d, v) - end - function MatrixAlgebraKit.eig_vals!(a, F, alg::$Alg) - check_input(eig_vals!, a, F, alg) - return complex(diagview(a)) - end - function MatrixAlgebraKit.eigh_vals!(a, F, alg::$Alg) - check_input(eigh_vals!, a, F, alg) - return real(diagview(a)) + @eval begin + function MatrixAlgebraKit.eig_full!(a, F, alg::$Alg) + check_input(eig_full!, a, F, alg) + d = complex(a) + v = δ(complex(eltype(a)), axes(a)) + return (d, v) + end + function MatrixAlgebraKit.eigh_full!(a, F, alg::$Alg) + check_input(eigh_full!, a, F, alg) + ishermitian(a) || throw(ArgumentError("Matrix must be Hermitian")) + d = real(a) + v = δ(eltype(a), axes(a)) + return (d, v) + end + function MatrixAlgebraKit.eig_vals!(a, F, alg::$Alg) + check_input(eig_vals!, a, F, alg) + return complex(diagview(a)) + end + function MatrixAlgebraKit.eigh_vals!(a, F, alg::$Alg) + check_input(eigh_vals!, a, F, alg) + return real(diagview(a)) + end end - end end # svd for f in [:svd_compact!, :svd_full!] - @eval begin - function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm) - check_input($f, a, F, alg) - u = δ(eltype(a), (axes(a, 1), dual(axes(a, 1)))) - s = real(a) - v = δ(eltype(a), (dual(axes(a, 2)), axes(a, 2))) - return (u, s, v) - end - function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm) - check_input($f, a, F, alg) - diagvalue = only(unique(diagview(a))) - u = δ(eltype(a), (axes(a, 1), dual(axes(a, 1)))) - s = abs(diagvalue) * δ(Bool, axes(a)) - # Sign is applied arbitarily to `v`, alternatively - # we could apply it to `u`. - v = sign(diagvalue) * δ(Bool, (dual(axes(a, 2)), axes(a, 2))) - return (u, s, v) + @eval begin + function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm) + check_input($f, a, F, alg) + u = δ(eltype(a), (axes(a, 1), dual(axes(a, 1)))) + s = real(a) + v = δ(eltype(a), (dual(axes(a, 2)), axes(a, 2))) + return (u, s, v) + end + function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm) + check_input($f, a, F, alg) + diagvalue = only(unique(diagview(a))) + u = δ(eltype(a), (axes(a, 1), dual(axes(a, 1)))) + s = abs(diagvalue) * δ(Bool, axes(a)) + # Sign is applied arbitarily to `v`, alternatively + # we could apply it to `u`. + v = sign(diagvalue) * δ(Bool, (dual(axes(a, 2)), axes(a, 2))) + return (u, s, v) + end end - end end function MatrixAlgebraKit.svd_vals!(a, F, alg::DeltaAlgorithm) - check_input(svd_vals!, a, F, alg) - # Using `real` instead of `abs.` helps to preserve `Ones`. - return real(diagview(a)) + check_input(svd_vals!, a, F, alg) + # Using `real` instead of `abs.` helps to preserve `Ones`. + return real(diagview(a)) end function MatrixAlgebraKit.svd_vals!(a, F, alg::ScaledDeltaAlgorithm) - check_input(svd_vals!, a, F, alg) - return abs.(diagview(a)) + check_input(svd_vals!, a, F, alg) + return abs.(diagview(a)) end # orth # left_orth is implicitly defined by defining backends like # qr_compact and left_polar. for f in [:left_polar!, :qr_compact!, :qr_full!] - @eval begin - function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm) - check_input($f, a, F, alg) - q = δ(eltype(a), (axes(a, 1), dual(axes(a, 1)))) - r = copy(a) - return (q, r) - end - function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm) - check_input($f, a, F, alg) - diagvalue = only(unique(diagview(a))) - q = sign(diagvalue) * δ(Bool, (axes(a, 1), dual(axes(a, 1)))) - # We're a bit pessimistic about the element type for type stability, - # since in the future we might provide the option to do non-positive QR. - r = eltype(a)(abs(diagvalue)) * δ(Bool, axes(a)) - return (q, r) + @eval begin + function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm) + check_input($f, a, F, alg) + q = δ(eltype(a), (axes(a, 1), dual(axes(a, 1)))) + r = copy(a) + return (q, r) + end + function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm) + check_input($f, a, F, alg) + diagvalue = only(unique(diagview(a))) + q = sign(diagvalue) * δ(Bool, (axes(a, 1), dual(axes(a, 1)))) + # We're a bit pessimistic about the element type for type stability, + # since in the future we might provide the option to do non-positive QR. + r = eltype(a)(abs(diagvalue)) * δ(Bool, axes(a)) + return (q, r) + end end - end end # right_orth is implicitly defined by defining backends like # lq_compact and right_polar. for f in [:right_polar!, :lq_compact!, :lq_full!] - @eval begin - function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm) - check_input($f, a, F, alg) - l = copy(a) - q = δ(eltype(a), (dual(axes(a, 2)), axes(a, 2))) - return (l, q) - end - function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm) - check_input($f, a, F, alg) - diagvalue = only(unique(diagview(a))) - # We're a bit pessimistic about the element type for type stability, - # since in the future we might provide the option to do non-positive LQ. - l = eltype(a)(abs(diagvalue)) * δ(Bool, axes(a)) - q = sign(diagvalue) * δ(Bool, (dual(axes(a, 2)), axes(a, 2))) - return (l, q) + @eval begin + function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm) + check_input($f, a, F, alg) + l = copy(a) + q = δ(eltype(a), (dual(axes(a, 2)), axes(a, 2))) + return (l, q) + end + function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm) + check_input($f, a, F, alg) + diagvalue = only(unique(diagview(a))) + # We're a bit pessimistic about the element type for type stability, + # since in the future we might provide the option to do non-positive LQ. + l = eltype(a)(abs(diagvalue)) * δ(Bool, axes(a)) + q = sign(diagvalue) * δ(Bool, (dual(axes(a, 2)), axes(a, 2))) + return (l, q) + end end - end end # null for T in [:DeltaMatrix, :ScaledDeltaMatrix] - @eval begin - # TODO: Right now we can't overload `left_null!` on an algorithm, - # make a PR to MatrixAlgebraKit.jl to allow that. - function MatrixAlgebraKit.left_null!(a::$T, F) - check_input(left_null!, a, F, default_qr_algorithm(a)) - return error("Not implemented.") - end - # TODO: Right now we can't overload `right_null!` on an algorithm, - # make a PR to MatrixAlgebraKit.jl to allow that. - function MatrixAlgebraKit.right_null!(a::$T, F) - check_input(right_null!, a, F, default_qr_algorithm(a)) - return error("Not implemented.") + @eval begin + # TODO: Right now we can't overload `left_null!` on an algorithm, + # make a PR to MatrixAlgebraKit.jl to allow that. + function MatrixAlgebraKit.left_null!(a::$T, F) + check_input(left_null!, a, F, default_qr_algorithm(a)) + return error("Not implemented.") + end + # TODO: Right now we can't overload `right_null!` on an algorithm, + # make a PR to MatrixAlgebraKit.jl to allow that. + function MatrixAlgebraKit.right_null!(a::$T, F) + check_input(right_null!, a, F, default_qr_algorithm(a)) + return error("Not implemented.") + end end - end end end diff --git a/src/abstractdiagonalarray/abstractdiagonalarray.jl b/src/abstractdiagonalarray/abstractdiagonalarray.jl index 6fca242..32e30c2 100644 --- a/src/abstractdiagonalarray/abstractdiagonalarray.jl +++ b/src/abstractdiagonalarray/abstractdiagonalarray.jl @@ -1,14 +1,14 @@ using SparseArraysBase: AbstractSparseArray -abstract type AbstractDiagonalArray{T,N} <: AbstractSparseArray{T,N} end -const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T,2} -const AbstractDiagonalVector{T} = AbstractDiagonalArray{T,1} +abstract type AbstractDiagonalArray{T, N} <: AbstractSparseArray{T, N} end +const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T, 2} +const AbstractDiagonalVector{T} = AbstractDiagonalArray{T, 1} # Define for type stability, for some reason the generic versions # in SparseArraysBase.jl is not type stable. # TODO: Investigate type stability of `iszero` in SparseArraysBase.jl. function Base.iszero(a::AbstractDiagonalArray) - return iszero(diagview(a)) + return iszero(diagview(a)) end using FillArrays: AbstractFill, getindex_value @@ -16,30 +16,30 @@ using LinearAlgebra: norm # TODO: `_norm` works around: # https://github.com/JuliaArrays/FillArrays.jl/issues/417 # Change back to `norm` when that is fixed. -_norm(a, p::Int=2) = norm(a, p) -function _norm(a::AbstractFill, p::Int=2) - nrm1 = norm(getindex_value(a)) - return (length(a))^(1/oftype(nrm1, p)) * nrm1 +_norm(a, p::Int = 2) = norm(a, p) +function _norm(a::AbstractFill, p::Int = 2) + nrm1 = norm(getindex_value(a)) + return (length(a))^(1 / oftype(nrm1, p)) * nrm1 end -function LinearAlgebra.norm(a::AbstractDiagonalArray, p::Int=2) - # TODO: `_norm` works around: - # https://github.com/JuliaArrays/FillArrays.jl/issues/417 - # Change back to `norm` when that is fixed. - return _norm(diagview(a), p) +function LinearAlgebra.norm(a::AbstractDiagonalArray, p::Int = 2) + # TODO: `_norm` works around: + # https://github.com/JuliaArrays/FillArrays.jl/issues/417 + # Change back to `norm` when that is fixed. + return _norm(diagview(a), p) end using LinearAlgebra: LinearAlgebra, ishermitian, isposdef, issymmetric LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Real}) = issquare(a) function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Number}) - return issquare(a) && isreal(diagview(a)) + return issquare(a) && isreal(diagview(a)) end function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix) - return issquare(a) && all(ishermitian, diagview(a)) + return issquare(a) && all(ishermitian, diagview(a)) end LinearAlgebra.issymmetric(a::AbstractDiagonalMatrix{<:Number}) = issquare(a) function LinearAlgebra.issymmetric(a::AbstractDiagonalMatrix) - return issquare(a) && all(issymmetric, diagview(a)) + return issquare(a) && all(issymmetric, diagview(a)) end function LinearAlgebra.isposdef(a::AbstractDiagonalMatrix) - return issquare(a) && all(isposdef, diagview(a)) + return issquare(a) && all(isposdef, diagview(a)) end diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index d9f203d..40b2cf9 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -4,88 +4,88 @@ diagview(a::AbstractDiagonalArray) = throw(MethodError(diagview, Tuple{typeof(a) using DerivableInterfaces: DerivableInterfaces, @interface using SparseArraysBase: - SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle + SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle abstract type AbstractDiagonalArrayInterface{N} <: AbstractSparseArrayInterface{N} end struct DiagonalArrayInterface{N} <: AbstractDiagonalArrayInterface{N} end -DiagonalArrayInterface{M}(::Val{N}) where {M,N} = DiagonalArrayInterface{N}() +DiagonalArrayInterface{M}(::Val{N}) where {M, N} = DiagonalArrayInterface{N}() DiagionalArrayInterface(::Val{N}) where {N} = DiagonalArrayInterface{N}() DiagonalArrayInterface() = DiagonalArrayInterface{Any}() function Base.similar(::AbstractDiagonalArrayInterface, elt::Type, ax::Tuple) - return similar(DiagonalArray{elt}, ax) + return similar(DiagonalArray{elt}, ax) end -function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArray{<:Any,N}}) where {N} - return DiagonalArrayInterface{N}() +function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArray{<:Any, N}}) where {N} + return DiagonalArrayInterface{N}() end abstract type AbstractDiagonalArrayStyle{N} <: AbstractSparseArrayStyle{N} end function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArrayStyle{N}}) where {N} - return DiagonalArrayInterface{N}() + return DiagonalArrayInterface{N}() end struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end -DiagonalArrayStyle{M}(::Val{N}) where {M,N} = DiagonalArrayStyle{N}() +DiagonalArrayStyle{M}(::Val{N}) where {M, N} = DiagonalArrayStyle{N}() function SparseArraysBase.isstored( - a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - return allequal(I) + a::AbstractDiagonalArray{<:Any, N}, I::Vararg{Int, N} + ) where {N} + return allequal(I) end function SparseArraysBase.getstoredindex( - a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - # TODO: Make this check optional, define `checkstored` like `checkbounds` - # in SparseArraysBase.jl. - # allequal(I) || error("Not a diagonal index.") - return getdiagindex(a, first(I)) + a::AbstractDiagonalArray{<:Any, N}, I::Vararg{Int, N} + ) where {N} + # TODO: Make this check optional, define `checkstored` like `checkbounds` + # in SparseArraysBase.jl. + # allequal(I) || error("Not a diagonal index.") + return getdiagindex(a, first(I)) end -function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any,0}) - return getdiagindex(a, 1) +function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any, 0}) + return getdiagindex(a, 1) end function SparseArraysBase.setstoredindex!( - a::AbstractDiagonalArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - # TODO: Make this check optional, define `checkstored` like `checkbounds` - # in SparseArraysBase.jl. - # allequal(I) || error("Not a diagonal index.") - setdiagindex!(a, value, first(I)) - return a -end -function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any,0}, value) - setdiagindex!(a, value, 1) - return a + a::AbstractDiagonalArray{<:Any, N}, value, I::Vararg{Int, N} + ) where {N} + # TODO: Make this check optional, define `checkstored` like `checkbounds` + # in SparseArraysBase.jl. + # allequal(I) || error("Not a diagonal index.") + setdiagindex!(a, value, first(I)) + return a +end +function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any, 0}, value) + setdiagindex!(a, value, 1) + return a end function SparseArraysBase.eachstoredindex(::IndexCartesian, a::AbstractDiagonalArray) - return diagindices(a) + return diagindices(a) end # Fix ambiguity error with SparseArraysBase. function Base.getindex(a::AbstractDiagonalArray, I::DiagIndices) - # TODO: Use `@interface` rather than `invoke`. - return invoke(getindex, Tuple{AbstractArray,DiagIndices}, a, I) + # TODO: Use `@interface` rather than `invoke`. + return invoke(getindex, Tuple{AbstractArray, DiagIndices}, a, I) end # Fix ambiguity error with SparseArraysBase. function Base.getindex(a::AbstractDiagonalArray, I::DiagIndex) - # TODO: Use `@interface` rather than `invoke`. - return invoke(getindex, Tuple{AbstractArray,DiagIndex}, a, I) + # TODO: Use `@interface` rather than `invoke`. + return invoke(getindex, Tuple{AbstractArray, DiagIndex}, a, I) end # Fix ambiguity error with SparseArraysBase. function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndices) - # TODO: Use `@interface` rather than `invoke`. - return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndices}, a, value, I) + # TODO: Use `@interface` rather than `invoke`. + return invoke(setindex!, Tuple{AbstractArray, Any, DiagIndices}, a, value, I) end # Fix ambiguity error with SparseArraysBase. function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex) - # TODO: Use `@interface` rather than `invoke`. - return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndex}, a, value, I) + # TODO: Use `@interface` rather than `invoke`. + return invoke(setindex!, Tuple{AbstractArray, Any, DiagIndex}, a, value, I) end @interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type) - return DiagonalArrayStyle{ndims(type)}() + return DiagonalArrayStyle{ndims(type)}() end using Base.Broadcast: Broadcasted, broadcasted @@ -93,16 +93,16 @@ using MapBroadcast: Mapped # Map to a flattened broadcast expression of the diagonals of the arrays, # also checking that the function preserves zeros. function broadcasted_diagview(bc::Broadcasted) - m = Mapped(bc) - iszero(m.f(map(zero ∘ eltype, m.args)...)) || error( - "Broadcasting DiagonalArrays with function that doesn't preserve zeros isn't supported yet.", - ) - return broadcasted(m.f, map(diagview, m.args)...) + m = Mapped(bc) + iszero(m.f(map(zero ∘ eltype, m.args)...)) || error( + "Broadcasting DiagonalArrays with function that doesn't preserve zeros isn't supported yet.", + ) + return broadcasted(m.f, map(diagview, m.args)...) end function Base.copy(bc::Broadcasted{<:DiagonalArrayStyle}) - return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc)) + return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc)) end function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle}) - copyto!(diagview(dest), broadcasted_diagview(bc)) - return dest + copyto!(diagview(dest), broadcasted_diagview(bc)) + return dest end diff --git a/src/abstractdiagonalarray/sparsearrayinterface.jl b/src/abstractdiagonalarray/sparsearrayinterface.jl index 89908c2..3f47a60 100644 --- a/src/abstractdiagonalarray/sparsearrayinterface.jl +++ b/src/abstractdiagonalarray/sparsearrayinterface.jl @@ -5,7 +5,7 @@ ## !allequal(Tuple(I)) && return nothing ## return first(Tuple(I)) ## end -## +## ## function SparseArraysBase.storage_index_to_index(a::AbstractDiagonalArray, I) ## return CartesianIndex(ntuple(Returns(I), ndims(a))) ## end diff --git a/src/diaginterface/diagindex.jl b/src/diaginterface/diagindex.jl index e177e80..67970af 100644 --- a/src/diaginterface/diagindex.jl +++ b/src/diaginterface/diagindex.jl @@ -1,14 +1,14 @@ # Represents a linear offset along the diagonal struct DiagIndex{I} - i::I + i::I end index(i::DiagIndex) = i.i function Base.getindex(a::AbstractArray, i::DiagIndex) - return getdiagindex(a, index(i)) + return getdiagindex(a, index(i)) end function Base.setindex!(a::AbstractArray, value, i::DiagIndex) - setdiagindex!(a, value, index(i)) - return a + setdiagindex!(a, value, index(i)) + return a end diff --git a/src/diaginterface/diagindices.jl b/src/diaginterface/diagindices.jl index 08590d1..986e72d 100644 --- a/src/diaginterface/diagindices.jl +++ b/src/diaginterface/diagindices.jl @@ -1,14 +1,14 @@ # Represents a set of linear offsets along the diagonal struct DiagIndices{I} - i::I + i::I end indices(i::DiagIndices) = i.i function Base.getindex(a::AbstractArray, I::DiagIndices) - return getdiagindices(a, indices(I)) + return getdiagindices(a, indices(I)) end function Base.setindex!(a::AbstractArray, value, i::DiagIndices) - setdiagindices!(a, value, indices(i)) - return a + setdiagindices!(a, value, indices(i)) + return a end diff --git a/src/diaginterface/diaginterface.jl b/src/diaginterface/diaginterface.jl index dcbc331..24fba66 100644 --- a/src/diaginterface/diaginterface.jl +++ b/src/diaginterface/diaginterface.jl @@ -2,25 +2,25 @@ using LinearAlgebra: LinearAlgebra -diaglength(a::AbstractArray{<:Any,0}) = 1 +diaglength(a::AbstractArray{<:Any, 0}) = 1 function diaglength(a::AbstractArray) - return minimum(size(a)) + return minimum(size(a)) end -@inline function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N} - @boundscheck checkbounds(a, I) - return allequal(Tuple(I)) +@inline function isdiagindex(a::AbstractArray{<:Any, N}, I::CartesianIndex{N}) where {N} + @boundscheck checkbounds(a, I) + return allequal(Tuple(I)) end function diagstride(a::AbstractArray) - s = 1 - p = 1 - for i in 1:(ndims(a) - 1) - p *= size(a, i) - s += p - end - return s + s = 1 + p = 1 + for i in 1:(ndims(a) - 1) + p *= size(a, i) + s += p + end + return s end # Iterator over the diagonal cartesian indices. @@ -28,86 +28,86 @@ end # to `@view CartesianIndices(a)[diagindices(a)]` but should be # faster because it avoids conversions from linear to cartesian indices. struct DiagCartesianIndices{N} <: AbstractVector{CartesianIndex{N}} - diaglength::Int + diaglength::Int end -function DiagCartesianIndices(axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}}) - # Check the ranges are one-based. - @assert all(isone, first.(axes)) - return DiagCartesianIndices{length(axes)}(minimum(length.(axes))) +function DiagCartesianIndices(axes::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) + # Check the ranges are one-based. + @assert all(isone, first.(axes)) + return DiagCartesianIndices{length(axes)}(minimum(length.(axes))) end -function DiagCartesianIndices(dims::Tuple{Int,Vararg{Int}}) - return DiagCartesianIndices(Base.OneTo.(dims)) +function DiagCartesianIndices(dims::Tuple{Int, Vararg{Int}}) + return DiagCartesianIndices(Base.OneTo.(dims)) end function DiagCartesianIndices(dims::Tuple{}) - return DiagCartesianIndices{0}(0) + return DiagCartesianIndices{0}(0) end function DiagCartesianIndices(a::AbstractArray) - return DiagCartesianIndices(axes(a)) + return DiagCartesianIndices(axes(a)) end Base.size(I::DiagCartesianIndices) = (I.diaglength,) function Base.getindex(I::DiagCartesianIndices{N}, i::Int) where {N} - return CartesianIndex(ntuple(Returns(i), N)) + return CartesianIndex(ntuple(Returns(i), N)) end function checkdiagbounds(::Type{Bool}, a::AbstractArray, i::Integer) - Base.require_one_based_indexing(a) - return i ∈ 1:diaglength(a) + Base.require_one_based_indexing(a) + return i ∈ 1:diaglength(a) end function checkdiagbounds(a::AbstractArray, i::Integer) - checkdiagbounds(Bool, a, i) || throw(BoundsError(a, ntuple(Returns(i), ndims(a)))) - return nothing + checkdiagbounds(Bool, a, i) || throw(BoundsError(a, ntuple(Returns(i), ndims(a)))) + return nothing end # Convert a linear index along the diagonal to the corresponding # CartesianIndex. @inline function diagindex(a::AbstractArray, i::Integer) - @boundscheck checkdiagbounds(a, i) - return CartesianIndex(ntuple(Returns(i), ndims(a))) + @boundscheck checkdiagbounds(a, i) + return CartesianIndex(ntuple(Returns(i), ndims(a))) end function diagindices(a::AbstractArray) - return diagindices(IndexStyle(a), a) + return diagindices(IndexStyle(a), a) end function diagindices(::IndexLinear, a::AbstractArray) - maxdiag = isempty(a) ? 0 : @inbounds LinearIndices(a)[diagindex(a, diaglength(a))] - return 1:diagstride(a):maxdiag + maxdiag = isempty(a) ? 0 : @inbounds LinearIndices(a)[diagindex(a, diaglength(a))] + return 1:diagstride(a):maxdiag end function diagindices(::IndexCartesian, a::AbstractArray) - return DiagCartesianIndices(a) + return DiagCartesianIndices(a) end -function diagindices(a::AbstractArray{<:Any,0}) - return Base.OneTo(1) +function diagindices(a::AbstractArray{<:Any, 0}) + return Base.OneTo(1) end function diagview(a::AbstractArray) - return @view a[diagindices(a)] + return @view a[diagindices(a)] end using LinearAlgebra: Diagonal diagview(a::Diagonal) = a.diag function getdiagindex(a::AbstractArray, i::Integer) - return diagview(a)[i] + return diagview(a)[i] end function setdiagindex!(a::AbstractArray, v, i::Integer) - diagview(a)[i] = v - return a + diagview(a)[i] = v + return a end function getdiagindices(a::AbstractArray, I) - # TODO: Should this be a view? - return @view diagview(a)[I] + # TODO: Should this be a view? + return @view diagview(a)[I] end function getdiagindices(a::AbstractArray, I::Colon) - return diagview(a) + return diagview(a) end function setdiagindices!(a::AbstractArray, v, i::Colon) - diagview(a) .= v - return a + diagview(a) .= v + return a end """ diff --git a/src/diagonalarray/arraylayouts.jl b/src/diagonalarray/arraylayouts.jl index 7f365db..e179c2a 100644 --- a/src/diagonalarray/arraylayouts.jl +++ b/src/diagonalarray/arraylayouts.jl @@ -5,7 +5,7 @@ default_diagonalarraytype(elt::Type) = DiagonalArray{elt} # TODO: Preserve GPU memory! Implement `CuSparseArrayLayout`, `MtlSparseLayout`? function Base.similar( - ::MulAdd{<:AbstractDiagonalLayout,<:AbstractDiagonalLayout}, elt::Type, axes -) - return similar(default_diagonalarraytype(elt), axes) + ::MulAdd{<:AbstractDiagonalLayout, <:AbstractDiagonalLayout}, elt::Type, axes + ) + return similar(default_diagonalarraytype(elt), axes) end diff --git a/src/diagonalarray/delta.jl b/src/diagonalarray/delta.jl index f91e312..77093b6 100644 --- a/src/diagonalarray/delta.jl +++ b/src/diagonalarray/delta.jl @@ -1,104 +1,104 @@ using FillArrays: AbstractFillVector, Ones, OnesVector -const ScaledDelta{T,N,Diag<:AbstractFillVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{ - T,N,Diag,Unstored +const ScaledDelta{T, N, Diag <: AbstractFillVector{T}, Unstored <: AbstractArray{T, N}} = DiagonalArray{ + T, N, Diag, Unstored, } -const ScaledDeltaVector{T,Diag<:AbstractFillVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{ - T,Diag,Unstored +const ScaledDeltaVector{T, Diag <: AbstractFillVector{T}, Unstored <: AbstractVector{T}} = DiagonalVector{ + T, Diag, Unstored, } -const ScaledDeltaMatrix{T,Diag<:AbstractFillVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{ - T,Diag,Unstored +const ScaledDeltaMatrix{T, Diag <: AbstractFillVector{T}, Unstored <: AbstractMatrix{T}} = DiagonalMatrix{ + T, Diag, Unstored, } -const Delta{T,N,Diag<:OnesVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{ - T,N,Diag,Unstored +const Delta{T, N, Diag <: OnesVector{T}, Unstored <: AbstractArray{T, N}} = DiagonalArray{ + T, N, Diag, Unstored, } -const DeltaVector{T,Diag<:OnesVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{ - T,Diag,Unstored +const DeltaVector{T, Diag <: OnesVector{T}, Unstored <: AbstractVector{T}} = DiagonalVector{ + T, Diag, Unstored, } -const DeltaMatrix{T,Diag<:OnesVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{ - T,Diag,Unstored +const DeltaMatrix{T, Diag <: OnesVector{T}, Unstored <: AbstractMatrix{T}} = DiagonalMatrix{ + T, Diag, Unstored, } function Delta{T}( - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} -) where {T} - uniquelens = unique(length, ax) - if !isone(length(uniquelens)) - throw(ArgumentError("All axes must have the same length for Delta.")) - end - return DiagonalArray{T}(Ones{T}(only(uniquelens)), ax) + ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}} + ) where {T} + uniquelens = unique(length, ax) + if !isone(length(uniquelens)) + throw(ArgumentError("All axes must have the same length for Delta.")) + end + return DiagonalArray{T}(Ones{T}(only(uniquelens)), ax) end function Delta{T}( - ax1::AbstractUnitRange{<:Integer}, ax_rest::AbstractUnitRange{<:Integer}... -) where {T} - return Delta{T}((ax1, ax_rest...)) + ax1::AbstractUnitRange{<:Integer}, ax_rest::AbstractUnitRange{<:Integer}... + ) where {T} + return Delta{T}((ax1, ax_rest...)) end -function Delta{T}(sz::Tuple{Integer,Vararg{Integer}}) where {T} - return Delta{T}(map(Base.OneTo, sz)) +function Delta{T}(sz::Tuple{Integer, Vararg{Integer}}) where {T} + return Delta{T}(map(Base.OneTo, sz)) end function Delta{T}(sz1::Integer, sz_rest::Integer...) where {T} - return Delta{T}((sz1, sz_rest...)) + return Delta{T}((sz1, sz_rest...)) end function Delta{T}(ax::Tuple{}) where {T} - return DiagonalArray{T}(Ones{T}(0), ax) + return DiagonalArray{T}(Ones{T}(0), ax) end function delta( - elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} -) - return Delta{elt}(ax) + elt::Type, ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}} + ) + return Delta{elt}(ax) end function δ( - elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} -) - return delta(elt, ax) + elt::Type, ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}} + ) + return delta(elt, ax) end -function delta(ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}) - return delta(Float64, ax) +function delta(ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}}) + return delta(Float64, ax) end -function δ(ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}) - return delta(Float64, ax) +function δ(ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}}) + return delta(Float64, ax) end function delta( - elt::Type, ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}... -) - return delta(elt, (ax1, axs...)) + elt::Type, ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}... + ) + return delta(elt, (ax1, axs...)) end function δ( - elt::Type, ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}... -) - return delta(elt, (ax1, axs...)) + elt::Type, ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}... + ) + return delta(elt, (ax1, axs...)) end function delta(ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}...) - return delta(Float64, (ax1, axs...)) + return delta(Float64, (ax1, axs...)) end function δ(ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}...) - return delta(Float64, (ax1, axs...)) + return delta(Float64, (ax1, axs...)) end function delta(elt::Type, size::Tuple{Vararg{Int}}) - return Delta{elt}(size) + return Delta{elt}(size) end function δ(elt::Type, size::Tuple{Vararg{Int}}) - return delta(elt, size) + return delta(elt, size) end function delta(elt::Type, size::Int...) - return delta(elt, size) + return delta(elt, size) end function δ(elt::Type, size::Int...) - return delta(elt, size...) + return delta(elt, size...) end function delta(size::Tuple{Vararg{Int}}) - return delta(Float64, size) + return delta(Float64, size) end function δ(size::Tuple{Vararg{Int}}) - return delta(size) + return delta(size) end function delta(size::Int...) - return delta(size) + return delta(size) end function δ(size::Int...) - return delta(size...) + return delta(size...) end diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index f0c85ac..8a608c4 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -1,52 +1,52 @@ using FillArrays: Zeros using SparseArraysBase: Unstored, unstored -diaglength_from_shape(sz::Tuple{Integer,Vararg{Integer}}) = minimum(sz) +diaglength_from_shape(sz::Tuple{Integer, Vararg{Integer}}) = minimum(sz) function diaglength_from_shape( - sz::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} -) - return minimum(length, sz) + sz::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}} + ) + return minimum(length, sz) end diaglength_from_shape(sz::Tuple{}) = 1 -struct DiagonalArray{T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} <: - AbstractDiagonalArray{T,N} - diag::D - unstored::U - function DiagonalArray{T,N,D,U}( - diag::AbstractVector, unstored::Unstored - ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - length(diag) == diaglength_from_shape(size(unstored)) || - throw(ArgumentError("Length of diagonals doesn't match dimensions")) - return new{T,N,D,U}(diag, parent(unstored)) - end +struct DiagonalArray{T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} <: + AbstractDiagonalArray{T, N} + diag::D + unstored::U + function DiagonalArray{T, N, D, U}( + diag::AbstractVector, unstored::Unstored + ) where {T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} + length(diag) == diaglength_from_shape(size(unstored)) || + throw(ArgumentError("Length of diagonals doesn't match dimensions")) + return new{T, N, D, U}(diag, parent(unstored)) + end end SparseArraysBase.unstored(a::DiagonalArray) = a.unstored Base.size(a::DiagonalArray) = size(unstored(a)) Base.axes(a::DiagonalArray) = axes(unstored(a)) -function DiagonalArray{T,N,D}( - diag::D, unstored::Unstored{T,N,U} -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(diag, unstored) +function DiagonalArray{T, N, D}( + diag::D, unstored::Unstored{T, N, U} + ) where {T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} + return DiagonalArray{T, N, D, U}(diag, unstored) end -function DiagonalArray{T,N}( - diag::D, unstored::Unstored{T,N} -) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D}(diag, unstored) +function DiagonalArray{T, N}( + diag::D, unstored::Unstored{T, N} + ) where {T, N, D <: AbstractVector{T}} + return DiagonalArray{T, N, D}(diag, unstored) end -function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N} - return DiagonalArray{T,N}(diag, unstored) +function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T, N}) where {T, N} + return DiagonalArray{T, N}(diag, unstored) end function DiagonalArray(diag::AbstractVector{T}, unstored::Unstored{T}) where {T} - return DiagonalArray{T}(diag, unstored) + return DiagonalArray{T}(diag, unstored) end function DiagonalArray(::UndefInitializer, unstored::Unstored) - return DiagonalArray( - Vector{eltype(unstored)}(undef, diaglength_from_shape(size(unstored))), unstored - ) + return DiagonalArray( + Vector{eltype(unstored)}(undef, diaglength_from_shape(size(unstored))), unstored + ) end # Indicate we will construct an array just from the shape, @@ -57,226 +57,226 @@ struct ShapeInitializer end # This is used to create custom constructors for arrays, # in this case a generic constructor of a vector from a length. function construct(vect::Type{<:AbstractVector}, ::ShapeInitializer, len::Integer) - if applicable(vect, len) - return vect(len) - elseif applicable(vect, (Base.OneTo(len),)) - return vect((Base.OneTo(len),)) - else - error(lazy"Can't construct $(vect) from length.") - end + if applicable(vect, len) + return vect(len) + elseif applicable(vect, (Base.OneTo(len),)) + return vect((Base.OneTo(len),)) + else + error(lazy"Can't construct $(vect) from length.") + end end # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, unstored::Unstored{T,N,U} -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}( - construct(D, init, diaglength_from_shape(axes(unstored))), unstored - ) +function DiagonalArray{T, N, D, U}( + init::ShapeInitializer, unstored::Unstored{T, N, U} + ) where {T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} + return DiagonalArray{T, N, D, U}( + construct(D, init, diaglength_from_shape(axes(unstored))), unstored + ) end -function DiagonalArray{T,N,D}( - init::ShapeInitializer, unstored::Unstored{T,N,U} -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, unstored) +function DiagonalArray{T, N, D}( + init::ShapeInitializer, unstored::Unstored{T, N, U} + ) where {T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} + return DiagonalArray{T, N, D, U}(init, unstored) end # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. # These versions use the default unstored type `Zeros{T,N}`. -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, Unstored(U(ax))) -end -function DiagonalArray{T,N,D}( - init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} -) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D}(init, Unstored(Zeros{T,N}(ax))) -end -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, ax) -end -function DiagonalArray{T,N,D}( - init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... -) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D}(init, ax) -end -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, Base.OneTo.(sz)) -end -function DiagonalArray{T,N,D}( - init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} -) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D}(init, Base.OneTo.(sz)) -end -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, sz1::Integer, sz_rest::Integer... -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, (sz1, sz_rest...)) -end -function DiagonalArray{T,N,D}( - init::ShapeInitializer, sz1::Integer, sz_rest::Integer... -) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D}(init, (sz1, sz_rest...)) +function DiagonalArray{T, N, D, U}( + init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} + ) where {T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} + return DiagonalArray{T, N, D, U}(init, Unstored(U(ax))) +end +function DiagonalArray{T, N, D}( + init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} + ) where {T, N, D <: AbstractVector{T}} + return DiagonalArray{T, N, D}(init, Unstored(Zeros{T, N}(ax))) +end +function DiagonalArray{T, N, D, U}( + init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... + ) where {T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} + return DiagonalArray{T, N, D, U}(init, ax) +end +function DiagonalArray{T, N, D}( + init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... + ) where {T, N, D <: AbstractVector{T}} + return DiagonalArray{T, N, D}(init, ax) +end +function DiagonalArray{T, N, D, U}( + init::ShapeInitializer, sz::Tuple{Integer, Vararg{Integer}} + ) where {T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} + return DiagonalArray{T, N, D, U}(init, Base.OneTo.(sz)) +end +function DiagonalArray{T, N, D}( + init::ShapeInitializer, sz::Tuple{Integer, Vararg{Integer}} + ) where {T, N, D <: AbstractVector{T}} + return DiagonalArray{T, N, D}(init, Base.OneTo.(sz)) +end +function DiagonalArray{T, N, D, U}( + init::ShapeInitializer, sz1::Integer, sz_rest::Integer... + ) where {T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} + return DiagonalArray{T, N, D, U}(init, (sz1, sz_rest...)) +end +function DiagonalArray{T, N, D}( + init::ShapeInitializer, sz1::Integer, sz_rest::Integer... + ) where {T, N, D <: AbstractVector{T}} + return DiagonalArray{T, N, D}(init, (sz1, sz_rest...)) end # Constructor from diagonal entries accepting axes. -function DiagonalArray{T,N}( - diag::AbstractVector, - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, -) where {T,N} - N == length(ax) || throw(ArgumentError("Wrong number of axes")) - return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(ax))) -end -function DiagonalArray{T,N}( - diag::AbstractVector, - ax1::AbstractUnitRange{<:Integer}, - axs::AbstractUnitRange{<:Integer}..., -) where {T,N} - return DiagonalArray{T,N}(diag, (ax1, axs...)) +function DiagonalArray{T, N}( + diag::AbstractVector, + ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}}, + ) where {T, N} + N == length(ax) || throw(ArgumentError("Wrong number of axes")) + return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(ax))) +end +function DiagonalArray{T, N}( + diag::AbstractVector, + ax1::AbstractUnitRange{<:Integer}, + axs::AbstractUnitRange{<:Integer}..., + ) where {T, N} + return DiagonalArray{T, N}(diag, (ax1, axs...)) end function DiagonalArray{T}( - diag::AbstractVector, - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, -) where {T} - return DiagonalArray{T,length(ax)}(diag, ax) + diag::AbstractVector, + ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}}, + ) where {T} + return DiagonalArray{T, length(ax)}(diag, ax) end function DiagonalArray{T}( - diag::AbstractVector, - ax1::AbstractUnitRange{<:Integer}, - axs::AbstractUnitRange{<:Integer}..., -) where {T} - return DiagonalArray{T}(diag, (ax1, axs...)) + diag::AbstractVector, + ax1::AbstractUnitRange{<:Integer}, + axs::AbstractUnitRange{<:Integer}..., + ) where {T} + return DiagonalArray{T}(diag, (ax1, axs...)) end function DiagonalArray( - diag::AbstractVector{T}, - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, -) where {T} - return DiagonalArray{T,length(ax)}(diag, ax) + diag::AbstractVector{T}, + ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}}, + ) where {T} + return DiagonalArray{T, length(ax)}(diag, ax) end function DiagonalArray( - diag::AbstractVector, - ax1::AbstractUnitRange{<:Integer}, - axs::AbstractUnitRange{<:Integer}..., -) - return DiagonalArray(diag, (ax1, axs...)) + diag::AbstractVector, + ax1::AbstractUnitRange{<:Integer}, + axs::AbstractUnitRange{<:Integer}..., + ) + return DiagonalArray(diag, (ax1, axs...)) end # undef constructors accepting axes. -function DiagonalArray{T,N}( - ::UndefInitializer, - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, -) where {T,N} - return DiagonalArray{T,N}(Vector{T}(undef, minimum(length, ax)), ax) -end -function DiagonalArray{T,N}( - ::UndefInitializer, - ax1::AbstractUnitRange{<:Integer}, - axs::AbstractUnitRange{<:Integer}..., -) where {T,N} - return DiagonalArray{T,N}(undef, (ax1, axs...)) +function DiagonalArray{T, N}( + ::UndefInitializer, + ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}}, + ) where {T, N} + return DiagonalArray{T, N}(Vector{T}(undef, minimum(length, ax)), ax) +end +function DiagonalArray{T, N}( + ::UndefInitializer, + ax1::AbstractUnitRange{<:Integer}, + axs::AbstractUnitRange{<:Integer}..., + ) where {T, N} + return DiagonalArray{T, N}(undef, (ax1, axs...)) end function DiagonalArray{T}( - ::UndefInitializer, - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, -) where {T} - return DiagonalArray{T,length(ax)}(undef, ax) + ::UndefInitializer, + ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}}, + ) where {T} + return DiagonalArray{T, length(ax)}(undef, ax) end function DiagonalArray{T}( - ::UndefInitializer, - ax1::AbstractUnitRange{<:Integer}, - axs::AbstractUnitRange{<:Integer}..., -) where {T} - return DiagonalArray{T}(undef, (ax1, axs...)) + ::UndefInitializer, + ax1::AbstractUnitRange{<:Integer}, + axs::AbstractUnitRange{<:Integer}..., + ) where {T} + return DiagonalArray{T}(undef, (ax1, axs...)) end -function DiagonalArray{T,N}(diag::AbstractVector, dims::Dims{N}) where {T,N} - return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(dims))) +function DiagonalArray{T, N}(diag::AbstractVector, dims::Dims{N}) where {T, N} + return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(dims))) end -function DiagonalArray{T,N}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(diag, dims) +function DiagonalArray{T, N}(diag::AbstractVector, dims::Vararg{Int, N}) where {T, N} + return DiagonalArray{T, N}(diag, dims) end -function DiagonalArray{T}(diag::AbstractVector, dims::Dims{N}) where {T,N} - return DiagonalArray{T,N}(diag, dims) +function DiagonalArray{T}(diag::AbstractVector, dims::Dims{N}) where {T, N} + return DiagonalArray{T, N}(diag, dims) end -function DiagonalArray{T}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(diag, dims) +function DiagonalArray{T}(diag::AbstractVector, dims::Vararg{Int, N}) where {T, N} + return DiagonalArray{T, N}(diag, dims) end -function DiagonalArray{<:Any,N}(diag::AbstractVector{T}, dims::Dims{N}) where {T,N} - return DiagonalArray{T,N}(diag, dims) +function DiagonalArray{<:Any, N}(diag::AbstractVector{T}, dims::Dims{N}) where {T, N} + return DiagonalArray{T, N}(diag, dims) end -function DiagonalArray{<:Any,N}(diag::AbstractVector{T}, dims::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(diag, dims) +function DiagonalArray{<:Any, N}(diag::AbstractVector{T}, dims::Vararg{Int, N}) where {T, N} + return DiagonalArray{T, N}(diag, dims) end -function DiagonalArray{<:Any,N}( - diag::AbstractVector{T}, - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, -) where {T,N} - return DiagonalArray{T,N}(diag, ax) +function DiagonalArray{<:Any, N}( + diag::AbstractVector{T}, + ax::Tuple{AbstractUnitRange{<:Integer}, Vararg{AbstractUnitRange{<:Integer}}}, + ) where {T, N} + return DiagonalArray{T, N}(diag, ax) end -function DiagonalArray(diag::AbstractVector{T}, dims::Dims{N}) where {T,N} - return DiagonalArray{T,N}(diag, dims) +function DiagonalArray(diag::AbstractVector{T}, dims::Dims{N}) where {T, N} + return DiagonalArray{T, N}(diag, dims) end -function DiagonalArray(diag::AbstractVector{T}, dims::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(diag, dims) +function DiagonalArray(diag::AbstractVector{T}, dims::Vararg{Int, N}) where {T, N} + return DiagonalArray{T, N}(diag, dims) end # Infer size from diagonal -function DiagonalArray{T,N}(diag::AbstractVector) where {T,N} - return DiagonalArray{T,N}(diag, ntuple(Returns(length(diag)), N)) +function DiagonalArray{T, N}(diag::AbstractVector) where {T, N} + return DiagonalArray{T, N}(diag, ntuple(Returns(length(diag)), N)) end -function DiagonalArray{<:Any,N}(diag::AbstractVector{T}) where {T,N} - return DiagonalArray{T,N}(diag) +function DiagonalArray{<:Any, N}(diag::AbstractVector{T}) where {T, N} + return DiagonalArray{T, N}(diag) end # undef -function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} - return DiagonalArray{T,N}(Vector{T}(undef, diaglength_from_shape(dims)), dims) +function DiagonalArray{T, N}(::UndefInitializer, dims::Dims{N}) where {T, N} + return DiagonalArray{T, N}(Vector{T}(undef, diaglength_from_shape(dims)), dims) end -function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(undef, dims) +function DiagonalArray{T, N}(::UndefInitializer, dims::Vararg{Int, N}) where {T, N} + return DiagonalArray{T, N}(undef, dims) end -function DiagonalArray{T}(::UndefInitializer, dims::Dims{N}) where {T,N} - return DiagonalArray{T,N}(undef, dims) +function DiagonalArray{T}(::UndefInitializer, dims::Dims{N}) where {T, N} + return DiagonalArray{T, N}(undef, dims) end -function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} - return DiagonalArray{T,N}(undef, dims) +function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int, N}) where {T, N} + return DiagonalArray{T, N}(undef, dims) end # Axes version function DiagonalArray{T}( - ::UndefInitializer, axes::Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}} -) where {T} - return DiagonalArray{T,length(axes)}(undef, length.(axes)) + ::UndefInitializer, axes::Tuple{Base.OneTo{Int}, Vararg{Base.OneTo{Int}}} + ) where {T} + return DiagonalArray{T, length(axes)}(undef, length.(axes)) end function Base.similar(a::DiagonalArray, unstored::Unstored) - return DiagonalArray(undef, unstored) + return DiagonalArray(undef, unstored) end # These definitions are helpful for immutable diagonals # such as FillArrays. for f in [:complex, :copy, :imag, :real] - @eval begin - Base.$f(a::DiagonalArray) = DiagonalArray($f(diagview(a)), axes(a)) - end + @eval begin + Base.$f(a::DiagonalArray) = DiagonalArray($f(diagview(a)), axes(a)) + end end # DiagonalArrays interface. @@ -284,19 +284,19 @@ diagview(a::DiagonalArray) = a.diag # Special case for permutedims that is friendlier for immutable storage. function Base.permutedims(a::DiagonalArray, perm) - ((ndims(a) == length(perm)) && isperm(perm)) || - throw(ArgumentError("Not a valid permutation")) - ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a)) - # Unlike `permutedims(::Diagonal, perm)`, we copy here. - return DiagonalArray(copy(diagview(a)), ax_perm) + ((ndims(a) == length(perm)) && isperm(perm)) || + throw(ArgumentError("Not a valid permutation")) + ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a)) + # Unlike `permutedims(::Diagonal, perm)`, we copy here. + return DiagonalArray(copy(diagview(a)), ax_perm) end function DerivableInterfaces.permuteddims(a::DiagonalArray, perm) - ((ndims(a) == length(perm)) && isperm(perm)) || - throw(ArgumentError("Not a valid permutation")) - ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a)) - # Unlike `permutedims(::Diagonal, perm)`, we copy here. - return DiagonalArray(diagview(a), ax_perm) + ((ndims(a) == length(perm)) && isperm(perm)) || + throw(ArgumentError("Not a valid permutation")) + ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a)) + # Unlike `permutedims(::Diagonal, perm)`, we copy here. + return DiagonalArray(diagview(a), ax_perm) end # Scalar indexing. @@ -305,75 +305,75 @@ one_based_range(r) = false one_based_range(r::Base.OneTo) = true one_based_range(r::Base.Slice) = true function _diag_axes(a::DiagonalArray, I...) - return map(ntuple(identity, ndims(a))) do d - return Base.axes1(axes(a, d)[I[d]]) - end + return map(ntuple(identity, ndims(a))) do d + return Base.axes1(axes(a, d)[I[d]]) + end end # A view that preserves the diagonal structure. function _view_diag(a::DiagonalArray, I...) - ax = _diag_axes(a, I...) - return DiagonalArray(view(diagview(a), Base.OneTo(minimum(length, I))), ax) + ax = _diag_axes(a, I...) + return DiagonalArray(view(diagview(a), Base.OneTo(minimum(length, I))), ax) end function _view_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...) - ax = _diag_axes(a, I1, Irest...) - return DiagonalArray(view(diagview(a), :), ax) + ax = _diag_axes(a, I1, Irest...) + return DiagonalArray(view(diagview(a), :), ax) end # A slice that preserves the diagonal structure. function _getindex_diag(a::DiagonalArray, I...) - ax = _diag_axes(a, I...) - return DiagonalArray(diagview(a)[Base.OneTo(minimum(length, I))], ax) + ax = _diag_axes(a, I...) + return DiagonalArray(diagview(a)[Base.OneTo(minimum(length, I))], ax) end function _getindex_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...) - ax = _diag_axes(a, I1, Irest...) - return DiagonalArray(diagview(a)[:], ax) + ax = _diag_axes(a, I1, Irest...) + return DiagonalArray(diagview(a)[:], ax) end function Base.view(a::DiagonalArray, I...) - I′ = to_indices(a, I) - return if all(one_based_range, I′) - _view_diag(a, I′...) - else - invoke(view, Tuple{AbstractArray,Vararg}, a, I′...) - end + I′ = to_indices(a, I) + return if all(one_based_range, I′) + _view_diag(a, I′...) + else + invoke(view, Tuple{AbstractArray, Vararg}, a, I′...) + end end function Base.getindex(a::DiagonalArray, I::Int...) - return @interface interface(a) a[I...] + return @interface interface(a) a[I...] end function Base.getindex(a::DiagonalArray, I::DiagIndex) - return getdiagindex(a, index(I)) + return getdiagindex(a, index(I)) end function Base.getindex(a::DiagonalArray, I::DiagIndices) - # TODO: Should this be a view? - return @view diagview(a)[indices(I)] + # TODO: Should this be a view? + return @view diagview(a)[indices(I)] end function Base.getindex(a::DiagonalArray, I...) - I′ = to_indices(a, I) - return if all(i -> i isa Real, I′) - # Catch scalar indexing case. - @interface interface(a) a[I...] - elseif all(one_based_range, I′) - _getindex_diag(a, I′...) - else - copy(view(a, I′...)) - end + I′ = to_indices(a, I) + return if all(i -> i isa Real, I′) + # Catch scalar indexing case. + @interface interface(a) a[I...] + elseif all(one_based_range, I′) + _getindex_diag(a, I′...) + else + copy(view(a, I′...)) + end end # Define in order to preserve immutable diagonals such as FillArrays. -function DiagonalArray{T,N}(a::DiagonalArray{T,N}) where {T,N} - # TODO: Should this copy? This matches the design of `LinearAlgebra.Diagonal`: - # https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112 - return a +function DiagonalArray{T, N}(a::DiagonalArray{T, N}) where {T, N} + # TODO: Should this copy? This matches the design of `LinearAlgebra.Diagonal`: + # https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112 + return a end -function DiagonalArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N} - return DiagonalArray{T,N}(diagview(a)) +function DiagonalArray{T, N}(a::DiagonalArray{<:Any, N}) where {T, N} + return DiagonalArray{T, N}(diagview(a)) end function DiagonalArray{T}(a::DiagonalArray) where {T} - return DiagonalArray{T,ndims(a)}(a) + return DiagonalArray{T, ndims(a)}(a) end function DiagonalArray(a::DiagonalArray) - return DiagonalArray{eltype(a),ndims(a)}(a) + return DiagonalArray{eltype(a), ndims(a)}(a) end -function Base.AbstractArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N} - return DiagonalArray{T,N}(a) +function Base.AbstractArray{T, N}(a::DiagonalArray{<:Any, N}) where {T, N} + return DiagonalArray{T, N}(a) end # TODO: These definitions work around this issue: @@ -408,8 +408,8 @@ _broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axe _broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a)) function Broadcast.broadcasted( - ::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T,N,Diag} -) where {F,T,N,Diag<:AbstractFill{T}} - # TODO: Check that `f` preserves zeros? - return DiagonalArray(_broadcasted(f, diagview(a)), axes(a)) + ::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T, N, Diag} + ) where {F, T, N, Diag <: AbstractFill{T}} + # TODO: Check that `f` preserves zeros? + return DiagonalArray(_broadcasted(f, diagview(a)), axes(a)) end diff --git a/src/diagonalarray/diagonalmatrix.jl b/src/diagonalarray/diagonalmatrix.jl index d3a310d..dddb1fa 100644 --- a/src/diagonalarray/diagonalmatrix.jl +++ b/src/diagonalarray/diagonalmatrix.jl @@ -1,5 +1,5 @@ -const DiagonalMatrix{T,Diag<:AbstractVector{T},Unstored<:AbstractMatrix{T}} = DiagonalArray{ - T,2,Diag,Unstored +const DiagonalMatrix{T, Diag <: AbstractVector{T}, Unstored <: AbstractMatrix{T}} = DiagonalArray{ + T, 2, Diag, Unstored, } # LinearAlgebra @@ -7,142 +7,142 @@ const DiagonalMatrix{T,Diag<:AbstractVector{T},Unstored<:AbstractMatrix{T}} = Di using LinearAlgebra: LinearAlgebra function mul_diagviews(a1, a2) - # TODO: Compare that duals are equal, or define a function to overload. - dual(axes(a1, 2)) == axes(a2, 1) || throw( - DimensionMismatch( - lazy"Incompatible dimensions for multiplication: $(axes(a1)) and $(axes(a2))" - ), - ) - d1 = diagview(a1) - d2 = diagview(a2) - l = min(length(d1), length(d2)) - d1′ = view(d1, Base.OneTo(l)) - d2′ = view(d2, Base.OneTo(l)) - return (d1′, d2′) + # TODO: Compare that duals are equal, or define a function to overload. + dual(axes(a1, 2)) == axes(a2, 1) || throw( + DimensionMismatch( + lazy"Incompatible dimensions for multiplication: $(axes(a1)) and $(axes(a2))" + ), + ) + d1 = diagview(a1) + d2 = diagview(a2) + l = min(length(d1), length(d2)) + d1′ = view(d1, Base.OneTo(l)) + d2′ = view(d2, Base.OneTo(l)) + return (d1′, d2′) end function mul!_diagviews(a_dest, a1, a2) - axes(a_dest, 1) == axes(a1, 1) || throw( - DimensionMismatch( - lazy"Incompatible dimensions for multiplication: $(axes(a_dest)) and $(axes(a1))" - ), - ) - axes(a_dest, 2) == axes(a2, 2) || throw( - DimensionMismatch( - lazy"Incompatible dimensions for multiplication: $(axes(a_dest)) and $(axes(a2))" - ), - ) - d_dest = diagview(a_dest) - d1, d2 = mul_diagviews(a1, a2) - return d_dest, d1, d2 + axes(a_dest, 1) == axes(a1, 1) || throw( + DimensionMismatch( + lazy"Incompatible dimensions for multiplication: $(axes(a_dest)) and $(axes(a1))" + ), + ) + axes(a_dest, 2) == axes(a2, 2) || throw( + DimensionMismatch( + lazy"Incompatible dimensions for multiplication: $(axes(a_dest)) and $(axes(a2))" + ), + ) + d_dest = diagview(a_dest) + d1, d2 = mul_diagviews(a1, a2) + return d_dest, d1, d2 end function Base.:*(a1::DiagonalMatrix, a2::DiagonalMatrix) - d1, d2 = mul_diagviews(a1, a2) - # TODO: Handle the rack-deficient case, for example: - # δ(3, 2) * δ(2, 3) - # Maybe pack the diagonal with zeros or allow rank-deficient DiagonalArrays? - return DiagonalMatrix(d1 .* d2, (axes(a1, 1), axes(a2, 2))) + d1, d2 = mul_diagviews(a1, a2) + # TODO: Handle the rack-deficient case, for example: + # δ(3, 2) * δ(2, 3) + # Maybe pack the diagonal with zeros or allow rank-deficient DiagonalArrays? + return DiagonalMatrix(d1 .* d2, (axes(a1, 1), axes(a2, 2))) end function LinearAlgebra.mul!(a_dest::DiagonalMatrix, a1::DiagonalMatrix, a2::DiagonalMatrix) - d_dest, d1, d2 = mul!_diagviews(a_dest, a1, a2) - # TODO: Handle the rack-deficient case. - d_dest .= d1 .* d2 - return a_dest + d_dest, d1, d2 = mul!_diagviews(a_dest, a1, a2) + # TODO: Handle the rack-deficient case. + d_dest .= d1 .* d2 + return a_dest end function LinearAlgebra.mul!( - a_dest::DiagonalMatrix, a1::DiagonalMatrix, a2::DiagonalMatrix, α::Number, β::Number -) - d_dest, d1, d2 = mul!_diagviews(a_dest, a1, a2) - # TODO: Handle the rack-deficient case. - d_dest .= d1 .* d2 .* α .+ d_dest .* β - return a_dest + a_dest::DiagonalMatrix, a1::DiagonalMatrix, a2::DiagonalMatrix, α::Number, β::Number + ) + d_dest, d1, d2 = mul!_diagviews(a_dest, a1, a2) + # TODO: Handle the rack-deficient case. + d_dest .= d1 .* d2 .* α .+ d_dest .* β + return a_dest end # Adapted from https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L866-L928. function LinearAlgebra.tr(a::DiagonalMatrix) - checksquare(a) - # TODO: Define as `sum(tr, diagview(a))` like LinearAlgebra.jl? - return sum(diagview(a)) + checksquare(a) + # TODO: Define as `sum(tr, diagview(a))` like LinearAlgebra.jl? + return sum(diagview(a)) end # TODO: Special case for FillArrays diagonals. function LinearAlgebra.det(a::DiagonalMatrix) - checksquare(a) - # TODO: Define as `prod(det, diagview(a))` like LinearAlgebra.jl? - return prod(diagview(a)) + checksquare(a) + # TODO: Define as `prod(det, diagview(a))` like LinearAlgebra.jl? + return prod(diagview(a)) end # TODO: Special case for FillArrays diagonals. function LinearAlgebra.logabsdet(a::DiagonalMatrix) - checksquare(a) - return mapreduce(((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2), diagview(a)) do x - return (log(abs(x)), sign(x)) - end + checksquare(a) + return mapreduce(((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2), diagview(a)) do x + return (log(abs(x)), sign(x)) + end end # TODO: Special case for FillArrays diagonals. function LinearAlgebra.logdet(a::DiagonalMatrix{<:Complex}) - checksquare(a) - z = sum(log, diagview(a)) - return complex(real(z), rem2pi(imag(z), RoundNearest)) + checksquare(a) + z = sum(log, diagview(a)) + return complex(real(z), rem2pi(imag(z), RoundNearest)) end # Matrix functions for f in [ - :exp, - :cis, - :log, - :sqrt, - :cos, - :sin, - :tan, - :csc, - :sec, - :cot, - :cosh, - :sinh, - :tanh, - :csch, - :sech, - :coth, - :acos, - :asin, - :atan, - :acsc, - :asec, - :acot, - :acosh, - :asinh, - :atanh, - :acsch, - :asech, - :acoth, -] - @eval begin - function Base.$f(a::DiagonalMatrix) - checksquare(a) - return DiagonalMatrix(_broadcast($f, diagview(a)), axes(a)) + :exp, + :cis, + :log, + :sqrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, + ] + @eval begin + function Base.$f(a::DiagonalMatrix) + checksquare(a) + return DiagonalMatrix(_broadcast($f, diagview(a)), axes(a)) + end end - end end # Cube root of a real-valued diagonal matrix function Base.cbrt(a::DiagonalMatrix{<:Real}) - checksquare(a) - return DiagonalMatrix(_broadcast(cbrt, diagview(a)), axes(a)) + checksquare(a) + return DiagonalMatrix(_broadcast(cbrt, diagview(a)), axes(a)) end function LinearAlgebra.inv(a::DiagonalMatrix) - checksquare(a) - # `DiagonalArrays._broadcast` works around issues like https://github.com/JuliaArrays/FillArrays.jl/issues/416 - # when the diagonal is a FillArray or similar lazy array. - d⁻¹ = _broadcast(inv, diagview(a)) - any(isinf, d⁻¹) && error("Singular Exception") - return DiagonalMatrix(d⁻¹, axes(a)) + checksquare(a) + # `DiagonalArrays._broadcast` works around issues like https://github.com/JuliaArrays/FillArrays.jl/issues/416 + # when the diagonal is a FillArray or similar lazy array. + d⁻¹ = _broadcast(inv, diagview(a)) + any(isinf, d⁻¹) && error("Singular Exception") + return DiagonalMatrix(d⁻¹, axes(a)) end # TODO: Support `atol` and `rtol` keyword arguments: # https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.pinv using LinearAlgebra: pinv function LinearAlgebra.pinv(a::DiagonalMatrix) - checksquare(a) - return DiagonalMatrix(_broadcast(pinv, diagview(a)), axes(a)) + checksquare(a) + return DiagonalMatrix(_broadcast(pinv, diagview(a)), axes(a)) end diff --git a/src/diagonalarray/diagonalvector.jl b/src/diagonalarray/diagonalvector.jl index ec3cde8..e1d51ab 100644 --- a/src/diagonalarray/diagonalvector.jl +++ b/src/diagonalarray/diagonalvector.jl @@ -1,7 +1,7 @@ -const DiagonalVector{T,Diag<:AbstractVector{T},Unstored<:AbstractVector{T}} = DiagonalArray{ - T,1,Diag,Unstored +const DiagonalVector{T, Diag <: AbstractVector{T}, Unstored <: AbstractVector{T}} = DiagonalArray{ + T, 1, Diag, Unstored, } function DiagonalVector(diag::AbstractVector) - return DiagonalArray{<:Any,1}(diag) + return DiagonalArray{<:Any, 1}(diag) end diff --git a/src/dual.jl b/src/dual.jl index 36b6ee5..cf2442a 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -6,6 +6,6 @@ issquare(a::AbstractMatrix) = (axes(a, 1) == dual(axes(a, 2))) # codomain is the dual of the domain. # Returns the codomain if the check passes. function checksquare(a::AbstractMatrix) - issquare(a) || throw(DimensionMismatch(lazy"matrix is not square: axes are $(axes(a))")) - return axes(a, 1) + issquare(a) || throw(DimensionMismatch(lazy"matrix is not square: axes are $(axes(a))")) + return axes(a, 1) end diff --git a/test/runtests.jl b/test/runtests.jl index e2c9599..39c332d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,58 +6,60 @@ using Suppressor: Suppressor const pat = r"(?:--group=)(\w+)" arg_id = findfirst(contains(pat), ARGS) const GROUP = uppercase( - if isnothing(arg_id) - get(ENV, "GROUP", "ALL") - else - only(match(pat, ARGS[arg_id]).captures) - end, + if isnothing(arg_id) + get(ENV, "GROUP", "ALL") + else + only(match(pat, ARGS[arg_id]).captures) + end, ) "match files of the form `test_*.jl`, but exclude `*setup*.jl`" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") end "match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" function isexamplefile(fn) - return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") + return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @time begin - # tests in groups based on folder structure - for testgroup in filter(isdir, readdir(@__DIR__)) - if GROUP == "ALL" || GROUP == uppercase(testgroup) - for file in filter(istestfile, readdir(joinpath(@__DIR__, testgroup); join=true)) - @eval @safetestset $file begin - include($file) + # tests in groups based on folder structure + for testgroup in filter(isdir, readdir(@__DIR__)) + if GROUP == "ALL" || GROUP == uppercase(testgroup) + for file in filter(istestfile, readdir(joinpath(@__DIR__, testgroup); join = true)) + @eval @safetestset $file begin + include($file) + end + end end - end end - end - # single files in top folder - for file in filter(istestfile, readdir(@__DIR__)) - (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion - @eval @safetestset $file begin - include($file) + # single files in top folder + for file in filter(istestfile, readdir(@__DIR__)) + (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion + @eval @safetestset $file begin + include($file) + end end - end - # test examples - examplepath = joinpath(@__DIR__, "..", "examples") - for (root, _, files) in walkdir(examplepath) - contains(chopprefix(root, @__DIR__), "setup") && continue - for file in filter(isexamplefile, files) - filename = joinpath(root, file) - @eval begin - @safetestset $file begin - $(Expr( - :macrocall, - GlobalRef(Suppressor, Symbol("@suppress")), - LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), - )) + # test examples + examplepath = joinpath(@__DIR__, "..", "examples") + for (root, _, files) in walkdir(examplepath) + contains(chopprefix(root, @__DIR__), "setup") && continue + for file in filter(isexamplefile, files) + filename = joinpath(root, file) + @eval begin + @safetestset $file begin + $( + Expr( + :macrocall, + GlobalRef(Suppressor, Symbol("@suppress")), + LineNumberNode(@__LINE__, @__FILE__), + :(include($filename)), + ) + ) + end + end end - end end - end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 15334c6..ae8789b 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,6 +3,6 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - # TODO: fix ambiguities - Aqua.test_all(DiagonalArrays; ambiguities=false) + # TODO: fix ambiguities + Aqua.test_all(DiagonalArrays; ambiguities = false) end diff --git a/test/test_basics.jl b/test/test_basics.jl index c40a913..2ecd323 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,497 +1,497 @@ using DerivableInterfaces: permuteddims using DiagonalArrays: - DiagonalArrays, - ShapeInitializer, - Delta, - DeltaMatrix, - DiagonalArray, - DiagonalMatrix, - ScaledDelta, - ScaledDeltaMatrix, - Unstored, - δ, - delta, - diagindices, - diaglength, - diagonal, - diagonaltype, - diagview, - getdiagindices + DiagonalArrays, + ShapeInitializer, + Delta, + DeltaMatrix, + DiagonalArray, + DiagonalMatrix, + ScaledDelta, + ScaledDeltaMatrix, + Unstored, + δ, + delta, + diagindices, + diaglength, + diagonal, + diagonaltype, + diagview, + getdiagindices using FillArrays: Fill, Ones, Zeros using LinearAlgebra: - Diagonal, det, ishermitian, isposdef, issymmetric, logdet, mul!, pinv, tr + Diagonal, det, ishermitian, isposdef, issymmetric, logdet, mul!, pinv, tr using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength using Test: @test, @test_throws, @testset, @test_broken, @inferred @testset "Test DiagonalArrays" begin - @testset "DiagonalArray (eltype=$elt)" for elt in ( - Float32, Float64, Complex{Float32}, Complex{Float64} - ) - @testset "Basics" begin - a = fill(one(elt), 2, 3) - @test diaglength(a) == 2 - a = fill(one(elt)) - @test diaglength(a) == 1 - end - @testset "diagindices" begin - a = randn(elt, ()) - @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:1 - @test isempty(diagindices(IndexCartesian(), a)) - - for a in ( - randn(elt, (0,)), - randn(elt, (0, 0)), - randn(elt, (0, 3)), - randn(elt, (3, 0)), - randn(elt, (0, 0, 0)), - randn(elt, (3, 3, 0)), - ) - @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:0 - @test isempty(diagindices(IndexCartesian(), a)) - end - - a = randn(elt, (3,)) - @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:3 - @test diagindices(IndexCartesian(), a) == CartesianIndex.(1:3) - - a = randn(elt, (4,)) - @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:4 - @test diagindices(IndexCartesian(), a) == CartesianIndex.(1:4) - - for a in (randn(elt, (3, 3)), randn(elt, (3, 4))) - @test diagindices(a) == diagindices(IndexLinear(), a) == 1:4:9 - @test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3)) - end - - a = randn(elt, (4, 3)) - @test diagindices(a) == diagindices(IndexLinear(), a) == 1:5:11 - @test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3)) - - for a in (randn(elt, (3, 3, 3)), randn(elt, (3, 3, 4))) - @test diagindices(a) == diagindices(IndexLinear(), a) == 1:13:27 - @test diagindices(IndexCartesian(), a) == - CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3)) - end - - a = randn(elt, (3, 4, 3)) - @test diagindices(a) == diagindices(IndexLinear(), a) == 1:16:33 - @test diagindices(IndexCartesian(), a) == - CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3)) - end - @testset "DiagonalArray constructors" begin - v = randn(elt, 2) - @test DiagonalArray(v, 2, 2) ≡ - DiagonalArray(v, (2, 2)) ≡ - DiagonalArray(v, Base.OneTo(2), Base.OneTo(2)) ≡ - DiagonalArray(v, (Base.OneTo(2), Base.OneTo(2))) ≡ - DiagonalArray{elt}(v, 2, 2) ≡ - DiagonalArray{elt}(v, (2, 2)) ≡ - DiagonalArray{elt}(v, Base.OneTo(2), Base.OneTo(2)) ≡ - DiagonalArray{elt}(v, (Base.OneTo(2), Base.OneTo(2))) ≡ - DiagonalArray{elt,2}(v, 2, 2) ≡ - DiagonalArray{elt,2}(v, (2, 2)) ≡ - DiagonalArray{elt,2}(v, Base.OneTo(2), Base.OneTo(2)) ≡ - DiagonalArray{elt,2}(v, (Base.OneTo(2), Base.OneTo(2))) - @test size(DiagonalArray{elt}(undef, 2, 2)) ≡ - size(DiagonalArray{elt}(undef, (2, 2))) ≡ - size(DiagonalArray{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ - size(DiagonalArray{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) ≡ - size(DiagonalArray{elt,2}(undef, 2, 2)) ≡ - size(DiagonalArray{elt,2}(undef, (2, 2))) ≡ - size(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ - size(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2)))) - @test elt ≡ - eltype(DiagonalArray{elt}(undef, 2, 2)) ≡ - eltype(DiagonalArray{elt}(undef, (2, 2))) ≡ - eltype(DiagonalArray{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ - eltype(DiagonalArray{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) ≡ - eltype(DiagonalMatrix{elt}(undef, 2, 2)) ≡ - eltype(DiagonalMatrix{elt}(undef, (2, 2))) ≡ - eltype(DiagonalMatrix{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ - eltype(DiagonalMatrix{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) - - # Special constructors for immutable diagonal. - init = ShapeInitializer() - U = Zeros{UInt32,2,Tuple{Base.OneTo{Int},Base.OneTo{Int}}} - @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, (2, 2)) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, 2, 2) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Base.OneTo.((2, 2))) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Base.OneTo.((2, 2))...) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, (2, 2)) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, 2, 2) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Unstored(Zeros{UInt32}(2, 2))) - - init = ShapeInitializer() - @test DiagonalMatrix(Ones{elt}(2)) ≡ - DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, Base.OneTo.((2, 2))) ≡ - DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}( - init, Base.OneTo.((2, 2))... - ) ≡ - DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, (2, 2)) ≡ - DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, 2, 2) ≡ - DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}( - init, Unstored(Zeros{elt}(2, 2)) + @testset "DiagonalArray (eltype=$elt)" for elt in ( + Float32, Float64, Complex{Float32}, Complex{Float64}, ) - - init = ShapeInitializer() - @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, Base.OneTo.((2, 2))) - @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}( - init, Base.OneTo.((2, 2))... - ) - @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, (2, 2)) - @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, 2, 2) - - # 0-dim constructors - v = randn(elt, 1) - @test DiagonalArray(v) ≡ - DiagonalArray(v, ()) ≡ - DiagonalArray{elt}(v) ≡ - DiagonalArray{elt}(v, ()) ≡ - DiagonalArray{elt,0}(v) ≡ - DiagonalArray{elt,0}(v, ()) - @test size(DiagonalArray{elt}(undef)) ≡ - size(DiagonalArray{elt}(undef, ())) ≡ - size(DiagonalArray{elt,0}(undef)) ≡ - size(DiagonalArray{elt,0}(undef, ())) - @test elt ≡ - eltype(DiagonalArray{elt}(undef)) ≡ - eltype(DiagonalArray{elt}(undef, ())) ≡ - eltype(DiagonalArray{elt,0}(undef)) ≡ - eltype(DiagonalArray{elt,0}(undef, ())) - - # Special constructors for immutable diagonal. - init = ShapeInitializer() - @test DiagonalArray{<:Any,0}(Base.OneTo(UInt32(1))) ≡ - DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, ()) ≡ - DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init) ≡ - DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}())) - end - @testset "0-dim operations" begin - diag = randn(elt, 1) - a = DiagonalArray(diag) - @test a[] == diag[1] - a[] = 2 - @test a[] == 2 - end - @testset "Conversion" begin - a = DiagonalMatrix(randn(elt, 2)) - @test DiagonalMatrix{elt}(a) ≡ a - @test DiagonalMatrix{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) - @test DiagonalArray{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) - @test DiagonalArray(a) ≡ a - @test AbstractMatrix{elt}(a) ≡ a - @test AbstractMatrix{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) - @test AbstractArray{elt}(a) ≡ a - @test AbstractArray{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) - end - @testset "Slicing" begin - # Slicing that preserves the diagonal structure. - a = DiagonalMatrix(randn(elt, 3)) - b = @view a[:, :] - @test b isa DiagonalMatrix{elt,<:SubArray{elt,1}} - @test diagview(b) ≡ view(diagview(a), :) - - a = DiagonalMatrix(randn(elt, 3)) - b = @view a[Base.OneTo(2), Base.OneTo(2)] - @test b isa DiagonalMatrix{elt,<:SubArray{elt,1}} - @test diagview(b) ≡ view(diagview(a), Base.OneTo(2)) - - a = DiagonalMatrix(randn(elt, 3)) - b = a[:, :] - @test typeof(b) == typeof(a) - @test diagview(b) == diagview(a) - - a = DiagonalMatrix(randn(elt, 3)) - b = a[Base.OneTo(2), Base.OneTo(2)] - @test typeof(b) == typeof(a) - @test diagview(b) == diagview(a)[Base.OneTo(2)] - - # Slicing that doesn't preserve the diagonal structure. - a = DiagonalMatrix(randn(elt, 3)) - b = @view a[2:3, 2:3] - @test b isa SubArray - @test b == Matrix(a)[2:3, 2:3] - - a = DiagonalMatrix(randn(elt, 3)) - b = a[2:3, 2:3] - @test b isa SparseMatrixDOK{elt} - @test b == Matrix(a)[2:3, 2:3] - @test storedlength(b) == 2 - end - @testset "permutedims" begin - a = DiagonalArray(randn(elt, 2), (2, 3, 4)) - b = permutedims(a, (3, 1, 2)) - @test diagview(b) == diagview(a) - @test diagview(b) ≢ diagview(a) - @test size(b) === (4, 2, 3) - end - @testset "DerivableInterfaces.permuteddims" begin - a = DiagonalArray(randn(elt, 2), (2, 3, 4)) - b = permuteddims(a, (3, 1, 2)) - @test diagview(b) ≡ diagview(a) - @test size(b) === (4, 2, 3) - end - @testset "Broadcasting" begin - a = DiagonalArray(randn(elt, 2), (2, 3)) - b = DiagonalArray(randn(elt, 2), (2, 3)) - c = a .+ 2 .* b - @test c ≈ Array(a) + 2 * Array(b) - # Non-zero-preserving functions not supported yet. - @test_broken a .+ 2 - - c = DiagonalArray{elt}(undef, (2, 3)) - c .= a .+ 2 .* b - @test c ≈ Array(a) + 2 * Array(b) - - # Non-zero-preserving functions not supported yet. - c = DiagonalArray{elt}(undef, (2, 3)) - @test_broken c .= a .+ 2 - - a_ones = DiagonalMatrix(Ones{elt}(2)) - a_zeros = DiagonalMatrix(Zeros{elt}(2)) - @test identity.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) - @test identity.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) - @test complex.(a_ones) ≡ DiagonalMatrix(Ones{complex(elt)}(2)) - @test complex.(a_zeros) ≡ DiagonalMatrix(Zeros{complex(elt)}(2)) - @test Float32.(a_ones) ≡ DiagonalMatrix(Ones{Float32}(2)) - @test Float32.(a_zeros) ≡ DiagonalMatrix(Zeros{Float32}(2)) - @test inv.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) - @test inv.(a_zeros) ≡ DiagonalMatrix(Fill(inv(zero(elt)), 2)) - @test pinv.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) - @test pinv.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) - @test sqrt.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) - @test sqrt.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) - if elt <: Real - @test cbrt.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) - @test cbrt.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) - end - @test exp.(a_ones) ≡ DiagonalMatrix(Fill(exp(one(elt)), 2)) - @test exp.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(exp(zero(elt)))}(2)) - @test cis.(a_ones) ≡ DiagonalMatrix(Fill(cis(one(elt)), 2)) - @test cis.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cis(zero(elt)))}(2)) - @test log.(a_ones) ≡ DiagonalMatrix(Zeros{typeof(log(one(elt)))}(2)) - @test log.(a_zeros) ≡ DiagonalMatrix(Fill(log(zero(elt)), 2)) - @test cos.(a_ones) ≡ DiagonalMatrix(Fill(cos(one(elt)), 2)) - @test cos.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cos(zero(elt)))}(2)) - @test sin.(a_ones) ≡ DiagonalMatrix(Fill(sin(one(elt)), 2)) - @test sin.(a_zeros) ≡ DiagonalMatrix(Zeros{typeof(sin(zero(elt)))}(2)) - @test tan.(a_ones) ≡ DiagonalMatrix(Fill(tan(one(elt)), 2)) - @test tan.(a_zeros) ≡ DiagonalMatrix(Zeros{typeof(tan(zero(elt)))}(2)) - @test sec.(a_ones) ≡ DiagonalMatrix(Fill(sec(one(elt)), 2)) - @test sec.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(sec(zero(elt)))}(2)) - @test cosh.(a_ones) ≡ DiagonalMatrix(Fill(cosh(one(elt)), 2)) - @test cosh.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cosh(zero(elt)))}(2)) - end - @testset "Array properties" begin - a = DiagonalMatrix(randn(elt, 2)) - @test !iszero(a) - - a = DiagonalMatrix(zeros(elt, 2)) - @test iszero(a) - - a = DiagonalMatrix(Zeros{elt}(2)) - @test iszero(a) - end - @testset "LinearAlgebra matrix properties" begin - @test ishermitian(DiagonalMatrix([1, 2])) - @test !ishermitian(DiagonalMatrix([1, 2], (2, 3))) - @test !ishermitian(DiagonalMatrix([1 + 1im, 2 + 2im])) - @test ishermitian(DiagonalMatrix([ones(2, 2), ones(3, 3)])) - @test !ishermitian(DiagonalMatrix([randn(2, 2), randn(3, 3)])) - - @test issymmetric(DiagonalMatrix([1, 2])) - @test !issymmetric(DiagonalMatrix([1, 2], (2, 3))) - @test issymmetric(DiagonalMatrix([1 + 1im, 2 + 2im])) - @test issymmetric(DiagonalMatrix([ones(2, 2), ones(3, 3)])) - @test !issymmetric(DiagonalMatrix([randn(2, 2), randn(3, 3)])) - @test !issymmetric(DiagonalMatrix([randn(2, 2), randn(2, 3)])) - - @test isposdef(DiagonalMatrix([1, 2])) - @test !isposdef(DiagonalMatrix([1, -2])) - @test !isposdef(DiagonalMatrix([1, 2], (2, 3))) - @test !isposdef(DiagonalMatrix([1 + 1im, 2 + 2im])) - @test isposdef(DiagonalMatrix([[1 0; 0 1], [2 0; 0 2]])) - @test !isposdef(DiagonalMatrix([randn(2, 2), randn(3, 3)])) - @test !isposdef(DiagonalMatrix([randn(2, 2), randn(2, 3)])) - end - @testset "LinearAlgebra matrix functions" begin - diag = randn(elt, 2) - a = DiagonalMatrix(diag) - @test tr(a) ≈ sum(diag) - @test det(a) ≈ prod(diag) - - # Use a positive diagonal in order to take the `log`. - diag = rand(elt, 2) - a = DiagonalMatrix(diag) - @test real(logdet(a)) ≈ real(sum(log, diag)) - @test imag(logdet(a)) ≈ rem2pi(imag(sum(log, diag)), RoundNearest) - - for f in [ - :exp, - :cis, - :log, - :sqrt, - :cos, - :sin, - :tan, - :csc, - :sec, - :cot, - :cosh, - :sinh, - :tanh, - :csch, - :sech, - :coth, - :acos, - :asin, - :atan, - :acot, - :asinh, - :atanh, - :acsch, - :asech, - ] - @eval begin - a = DiagonalMatrix(rand($elt, 2)) - @test $f(a) ≈ DiagonalMatrix($f.(diagview(a))) + @testset "Basics" begin + a = fill(one(elt), 2, 3) + @test diaglength(a) == 2 + a = fill(one(elt)) + @test diaglength(a) == 1 end - end - - for f in [:acsc, :asec, :acosh, :acoth] - @eval begin - a = DiagonalMatrix(inv.(rand($elt, 2))) - @test $f(a) ≈ DiagonalMatrix($f.(diagview(a))) + @testset "diagindices" begin + a = randn(elt, ()) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:1 + @test isempty(diagindices(IndexCartesian(), a)) + + for a in ( + randn(elt, (0,)), + randn(elt, (0, 0)), + randn(elt, (0, 3)), + randn(elt, (3, 0)), + randn(elt, (0, 0, 0)), + randn(elt, (3, 3, 0)), + ) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:0 + @test isempty(diagindices(IndexCartesian(), a)) + end + + a = randn(elt, (3,)) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:3 + @test diagindices(IndexCartesian(), a) == CartesianIndex.(1:3) + + a = randn(elt, (4,)) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:4 + @test diagindices(IndexCartesian(), a) == CartesianIndex.(1:4) + + for a in (randn(elt, (3, 3)), randn(elt, (3, 4))) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:4:9 + @test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3)) + end + + a = randn(elt, (4, 3)) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:5:11 + @test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3)) + + for a in (randn(elt, (3, 3, 3)), randn(elt, (3, 3, 4))) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:13:27 + @test diagindices(IndexCartesian(), a) == + CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3)) + end + + a = randn(elt, (3, 4, 3)) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:16:33 + @test diagindices(IndexCartesian(), a) == + CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3)) end - end - - if elt <: Real - a = DiagonalMatrix(randn(elt, 2)) - @test cbrt(a) ≈ DiagonalMatrix(cbrt.(diagview(a))) - end + @testset "DiagonalArray constructors" begin + v = randn(elt, 2) + @test DiagonalArray(v, 2, 2) ≡ + DiagonalArray(v, (2, 2)) ≡ + DiagonalArray(v, Base.OneTo(2), Base.OneTo(2)) ≡ + DiagonalArray(v, (Base.OneTo(2), Base.OneTo(2))) ≡ + DiagonalArray{elt}(v, 2, 2) ≡ + DiagonalArray{elt}(v, (2, 2)) ≡ + DiagonalArray{elt}(v, Base.OneTo(2), Base.OneTo(2)) ≡ + DiagonalArray{elt}(v, (Base.OneTo(2), Base.OneTo(2))) ≡ + DiagonalArray{elt, 2}(v, 2, 2) ≡ + DiagonalArray{elt, 2}(v, (2, 2)) ≡ + DiagonalArray{elt, 2}(v, Base.OneTo(2), Base.OneTo(2)) ≡ + DiagonalArray{elt, 2}(v, (Base.OneTo(2), Base.OneTo(2))) + @test size(DiagonalArray{elt}(undef, 2, 2)) ≡ + size(DiagonalArray{elt}(undef, (2, 2))) ≡ + size(DiagonalArray{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ + size(DiagonalArray{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) ≡ + size(DiagonalArray{elt, 2}(undef, 2, 2)) ≡ + size(DiagonalArray{elt, 2}(undef, (2, 2))) ≡ + size(DiagonalArray{elt, 2}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ + size(DiagonalArray{elt, 2}(undef, (Base.OneTo(2), Base.OneTo(2)))) + @test elt ≡ + eltype(DiagonalArray{elt}(undef, 2, 2)) ≡ + eltype(DiagonalArray{elt}(undef, (2, 2))) ≡ + eltype(DiagonalArray{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ + eltype(DiagonalArray{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) ≡ + eltype(DiagonalMatrix{elt}(undef, 2, 2)) ≡ + eltype(DiagonalMatrix{elt}(undef, (2, 2))) ≡ + eltype(DiagonalMatrix{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ + eltype(DiagonalMatrix{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) + + # Special constructors for immutable diagonal. + init = ShapeInitializer() + U = Zeros{UInt32, 2, Tuple{Base.OneTo{Int}, Base.OneTo{Int}}} + @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}}(init, (2, 2)) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}}(init, 2, 2) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}, U}(init, Base.OneTo.((2, 2))) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}, U}(init, Base.OneTo.((2, 2))...) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}, U}(init, (2, 2)) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}, U}(init, 2, 2) ≡ + DiagonalMatrix{UInt32, Base.OneTo{UInt32}, U}(init, Unstored(Zeros{UInt32}(2, 2))) + + init = ShapeInitializer() + @test DiagonalMatrix(Ones{elt}(2)) ≡ + DiagonalMatrix{elt, Ones{elt, 1, Tuple{Base.OneTo{Int}}}}(init, Base.OneTo.((2, 2))) ≡ + DiagonalMatrix{elt, Ones{elt, 1, Tuple{Base.OneTo{Int}}}}( + init, Base.OneTo.((2, 2))... + ) ≡ + DiagonalMatrix{elt, Ones{elt, 1, Tuple{Base.OneTo{Int}}}}(init, (2, 2)) ≡ + DiagonalMatrix{elt, Ones{elt, 1, Tuple{Base.OneTo{Int}}}}(init, 2, 2) ≡ + DiagonalMatrix{elt, Ones{elt, 1, Tuple{Base.OneTo{Int}}}}( + init, Unstored(Zeros{elt}(2, 2)) + ) + + init = ShapeInitializer() + @test_throws ErrorException DiagonalMatrix{elt, Vector{elt}}(init, Base.OneTo.((2, 2))) + @test_throws ErrorException DiagonalMatrix{elt, Vector{elt}}( + init, Base.OneTo.((2, 2))... + ) + @test_throws ErrorException DiagonalMatrix{elt, Vector{elt}}(init, (2, 2)) + @test_throws ErrorException DiagonalMatrix{elt, Vector{elt}}(init, 2, 2) + + # 0-dim constructors + v = randn(elt, 1) + @test DiagonalArray(v) ≡ + DiagonalArray(v, ()) ≡ + DiagonalArray{elt}(v) ≡ + DiagonalArray{elt}(v, ()) ≡ + DiagonalArray{elt, 0}(v) ≡ + DiagonalArray{elt, 0}(v, ()) + @test size(DiagonalArray{elt}(undef)) ≡ + size(DiagonalArray{elt}(undef, ())) ≡ + size(DiagonalArray{elt, 0}(undef)) ≡ + size(DiagonalArray{elt, 0}(undef, ())) + @test elt ≡ + eltype(DiagonalArray{elt}(undef)) ≡ + eltype(DiagonalArray{elt}(undef, ())) ≡ + eltype(DiagonalArray{elt, 0}(undef)) ≡ + eltype(DiagonalArray{elt, 0}(undef, ())) + + # Special constructors for immutable diagonal. + init = ShapeInitializer() + @test DiagonalArray{<:Any, 0}(Base.OneTo(UInt32(1))) ≡ + DiagonalArray{UInt32, 0, Base.OneTo{UInt32}}(init, ()) ≡ + DiagonalArray{UInt32, 0, Base.OneTo{UInt32}}(init) ≡ + DiagonalArray{UInt32, 0, Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}())) + end + @testset "0-dim operations" begin + diag = randn(elt, 1) + a = DiagonalArray(diag) + @test a[] == diag[1] + a[] = 2 + @test a[] == 2 + end + @testset "Conversion" begin + a = DiagonalMatrix(randn(elt, 2)) + @test DiagonalMatrix{elt}(a) ≡ a + @test DiagonalMatrix{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test DiagonalArray{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test DiagonalArray(a) ≡ a + @test AbstractMatrix{elt}(a) ≡ a + @test AbstractMatrix{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test AbstractArray{elt}(a) ≡ a + @test AbstractArray{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + end + @testset "Slicing" begin + # Slicing that preserves the diagonal structure. + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[:, :] + @test b isa DiagonalMatrix{elt, <:SubArray{elt, 1}} + @test diagview(b) ≡ view(diagview(a), :) + + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[Base.OneTo(2), Base.OneTo(2)] + @test b isa DiagonalMatrix{elt, <:SubArray{elt, 1}} + @test diagview(b) ≡ view(diagview(a), Base.OneTo(2)) + + a = DiagonalMatrix(randn(elt, 3)) + b = a[:, :] + @test typeof(b) == typeof(a) + @test diagview(b) == diagview(a) + + a = DiagonalMatrix(randn(elt, 3)) + b = a[Base.OneTo(2), Base.OneTo(2)] + @test typeof(b) == typeof(a) + @test diagview(b) == diagview(a)[Base.OneTo(2)] + + # Slicing that doesn't preserve the diagonal structure. + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[2:3, 2:3] + @test b isa SubArray + @test b == Matrix(a)[2:3, 2:3] + + a = DiagonalMatrix(randn(elt, 3)) + b = a[2:3, 2:3] + @test b isa SparseMatrixDOK{elt} + @test b == Matrix(a)[2:3, 2:3] + @test storedlength(b) == 2 + end + @testset "permutedims" begin + a = DiagonalArray(randn(elt, 2), (2, 3, 4)) + b = permutedims(a, (3, 1, 2)) + @test diagview(b) == diagview(a) + @test diagview(b) ≢ diagview(a) + @test size(b) === (4, 2, 3) + end + @testset "DerivableInterfaces.permuteddims" begin + a = DiagonalArray(randn(elt, 2), (2, 3, 4)) + b = permuteddims(a, (3, 1, 2)) + @test diagview(b) ≡ diagview(a) + @test size(b) === (4, 2, 3) + end + @testset "Broadcasting" begin + a = DiagonalArray(randn(elt, 2), (2, 3)) + b = DiagonalArray(randn(elt, 2), (2, 3)) + c = a .+ 2 .* b + @test c ≈ Array(a) + 2 * Array(b) + # Non-zero-preserving functions not supported yet. + @test_broken a .+ 2 + + c = DiagonalArray{elt}(undef, (2, 3)) + c .= a .+ 2 .* b + @test c ≈ Array(a) + 2 * Array(b) + + # Non-zero-preserving functions not supported yet. + c = DiagonalArray{elt}(undef, (2, 3)) + @test_broken c .= a .+ 2 + + a_ones = DiagonalMatrix(Ones{elt}(2)) + a_zeros = DiagonalMatrix(Zeros{elt}(2)) + @test identity.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test identity.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + @test complex.(a_ones) ≡ DiagonalMatrix(Ones{complex(elt)}(2)) + @test complex.(a_zeros) ≡ DiagonalMatrix(Zeros{complex(elt)}(2)) + @test Float32.(a_ones) ≡ DiagonalMatrix(Ones{Float32}(2)) + @test Float32.(a_zeros) ≡ DiagonalMatrix(Zeros{Float32}(2)) + @test inv.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test inv.(a_zeros) ≡ DiagonalMatrix(Fill(inv(zero(elt)), 2)) + @test pinv.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test pinv.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + @test sqrt.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test sqrt.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + if elt <: Real + @test cbrt.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test cbrt.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + end + @test exp.(a_ones) ≡ DiagonalMatrix(Fill(exp(one(elt)), 2)) + @test exp.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(exp(zero(elt)))}(2)) + @test cis.(a_ones) ≡ DiagonalMatrix(Fill(cis(one(elt)), 2)) + @test cis.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cis(zero(elt)))}(2)) + @test log.(a_ones) ≡ DiagonalMatrix(Zeros{typeof(log(one(elt)))}(2)) + @test log.(a_zeros) ≡ DiagonalMatrix(Fill(log(zero(elt)), 2)) + @test cos.(a_ones) ≡ DiagonalMatrix(Fill(cos(one(elt)), 2)) + @test cos.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cos(zero(elt)))}(2)) + @test sin.(a_ones) ≡ DiagonalMatrix(Fill(sin(one(elt)), 2)) + @test sin.(a_zeros) ≡ DiagonalMatrix(Zeros{typeof(sin(zero(elt)))}(2)) + @test tan.(a_ones) ≡ DiagonalMatrix(Fill(tan(one(elt)), 2)) + @test tan.(a_zeros) ≡ DiagonalMatrix(Zeros{typeof(tan(zero(elt)))}(2)) + @test sec.(a_ones) ≡ DiagonalMatrix(Fill(sec(one(elt)), 2)) + @test sec.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(sec(zero(elt)))}(2)) + @test cosh.(a_ones) ≡ DiagonalMatrix(Fill(cosh(one(elt)), 2)) + @test cosh.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cosh(zero(elt)))}(2)) + end + @testset "Array properties" begin + a = DiagonalMatrix(randn(elt, 2)) + @test !iszero(a) - a = DiagonalMatrix(randn(elt, 2)) - @test inv(a) ≈ DiagonalMatrix(inv.(diagview(a))) + a = DiagonalMatrix(zeros(elt, 2)) + @test iszero(a) - a = DiagonalMatrix(randn(elt, 2)) - @test pinv(a) ≈ DiagonalMatrix(pinv.(diagview(a))) - end - @testset "Matrix multiplication" begin - a1 = DiagonalArray{elt}(undef, (2, 3)) - a1[1, 1] = 11 - a1[2, 2] = 22 - a2 = DiagonalArray{elt}(undef, (3, 4)) - a2[1, 1] = 11 - a2[2, 2] = 22 - a2[3, 3] = 33 - a_dest = a1 * a2 - # TODO: Use `densearray` to make generic to GPU. - @test Array(a_dest) ≈ Array(a1) * Array(a2) - # TODO: Make this work with `ArrayLayouts`. - @test storedlength(a_dest) == 2 - @test a_dest isa DiagonalMatrix{elt} - - a_dest = DiagonalArray{elt}(undef, (2, 4)) - mul!(a_dest, a1, a2) - @test Array(a_dest) ≈ Array(a1) * Array(a2) - - a_dest = DiagonalArray(randn(elt, 2), (2, 4)) - a_dest′ = copy(a_dest) - mul!(a_dest′, a1, a2, 2, 3) - @test Array(a_dest′) ≈ Array(a1) * Array(a2) * 2 + Array(a_dest) * 3 - - # TODO: Make generic to GPU, use `allocate_randn`? - a2 = randn(elt, (3, 4)) - a_dest = a1 * a2 - # TODO: Use `densearray` to make generic to GPU. - @test Array(a_dest) ≈ Array(a1) * Array(a2) - @test storedlength(a_dest) == 8 - @test a_dest isa Matrix{elt} - - a2 = sparsezeros(elt, (3, 4)) - a2[1, 1] = 11 - a2[2, 2] = 22 - a2[3, 3] = 33 - a_dest = a1 * a2 - # TODO: Use `densearray` to make generic to GPU. - @test Array(a_dest) ≈ Array(a1) * Array(a2) - # TODO: Define `SparseMatrixDOK`. - # TODO: Make this work with `ArrayLayouts`. - @test storedlength(a_dest) == 2 - @test a_dest isa SparseArrayDOK{elt,2} - end - @testset "diagonal" begin - v = randn(2) - d = @inferred diagonal(v) - @test d isa Diagonal{eltype(v)} - @test diagview(d) === v - @test diagonaltype(v) === typeof(d) - - a = randn(2, 2) - d = @inferred diagonal(a) - @test d isa Diagonal{eltype(v)} - @test diagview(d) == diagview(a) - @test diagonaltype(a) === typeof(d) - - a = randn(3, 3) - @test getdiagindices(a, 2:3) == diagview(a)[2:3] - end - @testset "delta" begin - for (a, elt′) in ( - (delta(2, 2), Float64), - (delta(Base.OneTo(2), Base.OneTo(2)), Float64), - (δ(2, 2), Float64), - (δ(Base.OneTo(2), Base.OneTo(2)), Float64), - (delta((2, 2)), Float64), - (delta(Base.OneTo.((2, 2))), Float64), - (δ((2, 2)), Float64), - (δ(Base.OneTo.((2, 2))), Float64), - (delta(Bool, 2, 2), Bool), - (delta(Bool, Base.OneTo(2), Base.OneTo(2)), Bool), - (Delta{Bool}((2, 2)), Bool), - (Delta{Bool}(Base.OneTo.((2, 2))), Bool), - (δ(Bool, 2, 2), Bool), - (δ(Bool, Base.OneTo(2), Base.OneTo(2)), Bool), - (delta(Bool, (2, 2)), Bool), - (delta(Bool, Base.OneTo.((2, 2))), Bool), - (δ(Bool, (2, 2)), Bool), - (δ(Bool, Base.OneTo.((2, 2))), Bool), - ) - @test eltype(a) === elt′ - @test diaglength(a) == 2 - @test a isa DiagonalArray{elt′,2} - @test a isa DiagonalMatrix{elt′} - @test a isa Delta{elt′,2} - @test a isa DeltaMatrix{elt′} - @test size(a) == (2, 2) - @test diaglength(a) == 2 - @test storedlength(a) == 2 - @test a == DiagonalArray(ones(2), (2, 2)) - @test diagview(a) == ones(2) - @test diagview(a) isa Ones{elt′} - @test copy(a) ≡ a - - a′ = 2a - @test diagview(a′) == 2ones(2) - # TODO: Fix this. Mapping doesn't preserve - # the diagonal structure properly. - # https://github.com/ITensor/DiagonalArrays.jl/issues/7 - @test diagview(a′) isa Fill{promote_type(Int, elt′)} - @test a′ isa ScaledDelta{promote_type(Int, elt′),2} - @test a′ isa ScaledDeltaMatrix{promote_type(Int, elt′)} - - b = randn(elt, (2, 3)) - a_dest = a * b - @test a_dest ≈ Array(a) * Array(b) - - a_dest = a * a - @test a_dest ≈ Array(a) * Array(a) - @test diagview(a_dest) isa Ones{elt′} - end + a = DiagonalMatrix(Zeros{elt}(2)) + @test iszero(a) + end + @testset "LinearAlgebra matrix properties" begin + @test ishermitian(DiagonalMatrix([1, 2])) + @test !ishermitian(DiagonalMatrix([1, 2], (2, 3))) + @test !ishermitian(DiagonalMatrix([1 + 1im, 2 + 2im])) + @test ishermitian(DiagonalMatrix([ones(2, 2), ones(3, 3)])) + @test !ishermitian(DiagonalMatrix([randn(2, 2), randn(3, 3)])) + + @test issymmetric(DiagonalMatrix([1, 2])) + @test !issymmetric(DiagonalMatrix([1, 2], (2, 3))) + @test issymmetric(DiagonalMatrix([1 + 1im, 2 + 2im])) + @test issymmetric(DiagonalMatrix([ones(2, 2), ones(3, 3)])) + @test !issymmetric(DiagonalMatrix([randn(2, 2), randn(3, 3)])) + @test !issymmetric(DiagonalMatrix([randn(2, 2), randn(2, 3)])) + + @test isposdef(DiagonalMatrix([1, 2])) + @test !isposdef(DiagonalMatrix([1, -2])) + @test !isposdef(DiagonalMatrix([1, 2], (2, 3))) + @test !isposdef(DiagonalMatrix([1 + 1im, 2 + 2im])) + @test isposdef(DiagonalMatrix([[1 0; 0 1], [2 0; 0 2]])) + @test !isposdef(DiagonalMatrix([randn(2, 2), randn(3, 3)])) + @test !isposdef(DiagonalMatrix([randn(2, 2), randn(2, 3)])) + end + @testset "LinearAlgebra matrix functions" begin + diag = randn(elt, 2) + a = DiagonalMatrix(diag) + @test tr(a) ≈ sum(diag) + @test det(a) ≈ prod(diag) + + # Use a positive diagonal in order to take the `log`. + diag = rand(elt, 2) + a = DiagonalMatrix(diag) + @test real(logdet(a)) ≈ real(sum(log, diag)) + @test imag(logdet(a)) ≈ rem2pi(imag(sum(log, diag)), RoundNearest) + + for f in [ + :exp, + :cis, + :log, + :sqrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acot, + :asinh, + :atanh, + :acsch, + :asech, + ] + @eval begin + a = DiagonalMatrix(rand($elt, 2)) + @test $f(a) ≈ DiagonalMatrix($f.(diagview(a))) + end + end + + for f in [:acsc, :asec, :acosh, :acoth] + @eval begin + a = DiagonalMatrix(inv.(rand($elt, 2))) + @test $f(a) ≈ DiagonalMatrix($f.(diagview(a))) + end + end + + if elt <: Real + a = DiagonalMatrix(randn(elt, 2)) + @test cbrt(a) ≈ DiagonalMatrix(cbrt.(diagview(a))) + end + + a = DiagonalMatrix(randn(elt, 2)) + @test inv(a) ≈ DiagonalMatrix(inv.(diagview(a))) + + a = DiagonalMatrix(randn(elt, 2)) + @test pinv(a) ≈ DiagonalMatrix(pinv.(diagview(a))) + end + @testset "Matrix multiplication" begin + a1 = DiagonalArray{elt}(undef, (2, 3)) + a1[1, 1] = 11 + a1[2, 2] = 22 + a2 = DiagonalArray{elt}(undef, (3, 4)) + a2[1, 1] = 11 + a2[2, 2] = 22 + a2[3, 3] = 33 + a_dest = a1 * a2 + # TODO: Use `densearray` to make generic to GPU. + @test Array(a_dest) ≈ Array(a1) * Array(a2) + # TODO: Make this work with `ArrayLayouts`. + @test storedlength(a_dest) == 2 + @test a_dest isa DiagonalMatrix{elt} + + a_dest = DiagonalArray{elt}(undef, (2, 4)) + mul!(a_dest, a1, a2) + @test Array(a_dest) ≈ Array(a1) * Array(a2) + + a_dest = DiagonalArray(randn(elt, 2), (2, 4)) + a_dest′ = copy(a_dest) + mul!(a_dest′, a1, a2, 2, 3) + @test Array(a_dest′) ≈ Array(a1) * Array(a2) * 2 + Array(a_dest) * 3 + + # TODO: Make generic to GPU, use `allocate_randn`? + a2 = randn(elt, (3, 4)) + a_dest = a1 * a2 + # TODO: Use `densearray` to make generic to GPU. + @test Array(a_dest) ≈ Array(a1) * Array(a2) + @test storedlength(a_dest) == 8 + @test a_dest isa Matrix{elt} + + a2 = sparsezeros(elt, (3, 4)) + a2[1, 1] = 11 + a2[2, 2] = 22 + a2[3, 3] = 33 + a_dest = a1 * a2 + # TODO: Use `densearray` to make generic to GPU. + @test Array(a_dest) ≈ Array(a1) * Array(a2) + # TODO: Define `SparseMatrixDOK`. + # TODO: Make this work with `ArrayLayouts`. + @test storedlength(a_dest) == 2 + @test a_dest isa SparseArrayDOK{elt, 2} + end + @testset "diagonal" begin + v = randn(2) + d = @inferred diagonal(v) + @test d isa Diagonal{eltype(v)} + @test diagview(d) === v + @test diagonaltype(v) === typeof(d) + + a = randn(2, 2) + d = @inferred diagonal(a) + @test d isa Diagonal{eltype(v)} + @test diagview(d) == diagview(a) + @test diagonaltype(a) === typeof(d) + + a = randn(3, 3) + @test getdiagindices(a, 2:3) == diagview(a)[2:3] + end + @testset "delta" begin + for (a, elt′) in ( + (delta(2, 2), Float64), + (delta(Base.OneTo(2), Base.OneTo(2)), Float64), + (δ(2, 2), Float64), + (δ(Base.OneTo(2), Base.OneTo(2)), Float64), + (delta((2, 2)), Float64), + (delta(Base.OneTo.((2, 2))), Float64), + (δ((2, 2)), Float64), + (δ(Base.OneTo.((2, 2))), Float64), + (delta(Bool, 2, 2), Bool), + (delta(Bool, Base.OneTo(2), Base.OneTo(2)), Bool), + (Delta{Bool}((2, 2)), Bool), + (Delta{Bool}(Base.OneTo.((2, 2))), Bool), + (δ(Bool, 2, 2), Bool), + (δ(Bool, Base.OneTo(2), Base.OneTo(2)), Bool), + (delta(Bool, (2, 2)), Bool), + (delta(Bool, Base.OneTo.((2, 2))), Bool), + (δ(Bool, (2, 2)), Bool), + (δ(Bool, Base.OneTo.((2, 2))), Bool), + ) + @test eltype(a) === elt′ + @test diaglength(a) == 2 + @test a isa DiagonalArray{elt′, 2} + @test a isa DiagonalMatrix{elt′} + @test a isa Delta{elt′, 2} + @test a isa DeltaMatrix{elt′} + @test size(a) == (2, 2) + @test diaglength(a) == 2 + @test storedlength(a) == 2 + @test a == DiagonalArray(ones(2), (2, 2)) + @test diagview(a) == ones(2) + @test diagview(a) isa Ones{elt′} + @test copy(a) ≡ a + + a′ = 2a + @test diagview(a′) == 2ones(2) + # TODO: Fix this. Mapping doesn't preserve + # the diagonal structure properly. + # https://github.com/ITensor/DiagonalArrays.jl/issues/7 + @test diagview(a′) isa Fill{promote_type(Int, elt′)} + @test a′ isa ScaledDelta{promote_type(Int, elt′), 2} + @test a′ isa ScaledDeltaMatrix{promote_type(Int, elt′)} + + b = randn(elt, (2, 3)) + a_dest = a * b + @test a_dest ≈ Array(a) * Array(b) + + a_dest = a * a + @test a_dest ≈ Array(a) * Array(a) + @test diagview(a_dest) isa Ones{elt′} + end + end end - end end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index e281a83..16241e3 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -3,26 +3,26 @@ using LinearAlgebra: Diagonal using DiagonalArrays: DiagonalArrays, DeltaMatrix, ScaledDeltaMatrix, δ, dual using FillArrays: Ones using MatrixAlgebraKit: - eig_full, - eig_vals, - eigh_full, - eigh_vals, - left_orth, - left_polar, - lq_compact, - lq_full, - qr_compact, - qr_full, - right_orth, - right_polar, - svd_compact, - svd_full, - svd_vals + eig_full, + eig_vals, + eigh_full, + eigh_vals, + left_orth, + left_polar, + lq_compact, + lq_full, + qr_compact, + qr_full, + right_orth, + right_polar, + svd_compact, + svd_full, + svd_vals using StableRNGs: StableRNG struct SU2 <: AbstractUnitRange{Int} - j::Int - isdual::Bool + j::Int + isdual::Bool end SU2(j::Int) = SU2(j, false) Base.:(==)(s1::SU2, s2::SU2) = ((s1.j == s2.j) && (s1.isdual == s2.isdual)) @@ -31,183 +31,183 @@ Base.last(s::SU2) = 2 * s.j + 1 DiagonalArrays.dual(s::SU2) = SU2(s.j, !s.isdual) @testset "MatrixAlgebraKitExt" begin - @testset "DeltaMatrix factorizations (eltype=$elt)" for elt in ( - Float32, Float64, ComplexF32, ComplexF64 - ) - @testset "SVD" begin - for f in (svd_compact, svd_full) - ax = SU2(2) - a = δ(elt, (ax, ax)) - u, s, v = f(a) - @test u * s * v ≡ a - @test u ≡ δ(elt, (ax, dual(ax))) - @test s ≡ δ(real(elt), (ax, ax)) - @test v ≡ δ(elt, (dual(ax), ax)) - end + @testset "DeltaMatrix factorizations (eltype=$elt)" for elt in ( + Float32, Float64, ComplexF32, ComplexF64, + ) + @testset "SVD" begin + for f in (svd_compact, svd_full) + ax = SU2(2) + a = δ(elt, (ax, ax)) + u, s, v = f(a) + @test u * s * v ≡ a + @test u ≡ δ(elt, (ax, dual(ax))) + @test s ≡ δ(real(elt), (ax, ax)) + @test v ≡ δ(elt, (dual(ax), ax)) + end + end + @testset "SVD values" begin + ax = SU2(2) + a = δ(elt, (ax, ax)) + s = svd_vals(a) + @test s ≡ Ones(real(elt), length(ax)) + end + @testset "left orth" begin + for f in (left_orth, left_polar, qr_compact, qr_full) + ax = SU2(2) + a = δ(elt, (ax, ax)) + q, r = f(a) + @test q * r ≡ a + @test q ≡ δ(elt, (ax, dual(ax))) + @test r ≡ δ(elt, (ax, ax)) + end + end + @testset "right orth" begin + for f in (lq_compact, lq_full, right_orth, right_polar) + ax = SU2(2) + a = δ(elt, (ax, ax)) + l, q = f(a) + @test l * q ≡ a + @test l ≡ δ(elt, (ax, ax)) + @test q ≡ δ(elt, (dual(ax), ax)) + end + end + @testset "Eigendecomposition" begin + ax = SU2(2) + a = δ(elt, (dual(ax), ax)) + d, v = eig_full(a) + @test a * v ≡ v * d + @test d ≡ δ(complex(elt), (dual(ax), ax)) + @test v ≡ δ(complex(elt), (dual(ax), ax)) + end + @testset "Hermitian eigendecomposition" begin + ax = SU2(2) + a = δ(elt, (dual(ax), ax)) + d, v = eigh_full(a) + @test a * v ≡ v * d + @test d ≡ δ(real(elt), (dual(ax), ax)) + @test v ≡ δ(elt, (dual(ax), ax)) + end + @testset "Eigenvalues" begin + ax = SU2(2) + a = δ(elt, (dual(ax), ax)) + d = eig_vals(a) + @test d ≡ Ones{complex(elt)}(length(ax)) + end + @testset "Hermitian eigenvalues" begin + ax = SU2(2) + a = δ(elt, (dual(ax), ax)) + d = eigh_vals(a) + @test d ≡ Ones{real(elt)}(length(ax)) + end + @testset "left null" begin + ax = SU2(2) + a = δ(elt, (ax, ax)) + @test_broken left_null(a) + end + @testset "right null" begin + ax = SU2(2) + a = δ(elt, (ax, ax)) + @test_broken right_null(a) + end end - @testset "SVD values" begin - ax = SU2(2) - a = δ(elt, (ax, ax)) - s = svd_vals(a) - @test s ≡ Ones(real(elt), length(ax)) - end - @testset "left orth" begin - for f in (left_orth, left_polar, qr_compact, qr_full) - ax = SU2(2) - a = δ(elt, (ax, ax)) - q, r = f(a) - @test q * r ≡ a - @test q ≡ δ(elt, (ax, dual(ax))) - @test r ≡ δ(elt, (ax, ax)) - end - end - @testset "right orth" begin - for f in (lq_compact, lq_full, right_orth, right_polar) - ax = SU2(2) - a = δ(elt, (ax, ax)) - l, q = f(a) - @test l * q ≡ a - @test l ≡ δ(elt, (ax, ax)) - @test q ≡ δ(elt, (dual(ax), ax)) - end - end - @testset "Eigendecomposition" begin - ax = SU2(2) - a = δ(elt, (dual(ax), ax)) - d, v = eig_full(a) - @test a * v ≡ v * d - @test d ≡ δ(complex(elt), (dual(ax), ax)) - @test v ≡ δ(complex(elt), (dual(ax), ax)) - end - @testset "Hermitian eigendecomposition" begin - ax = SU2(2) - a = δ(elt, (dual(ax), ax)) - d, v = eigh_full(a) - @test a * v ≡ v * d - @test d ≡ δ(real(elt), (dual(ax), ax)) - @test v ≡ δ(elt, (dual(ax), ax)) - end - @testset "Eigenvalues" begin - ax = SU2(2) - a = δ(elt, (dual(ax), ax)) - d = eig_vals(a) - @test d ≡ Ones{complex(elt)}(length(ax)) - end - @testset "Hermitian eigenvalues" begin - ax = SU2(2) - a = δ(elt, (dual(ax), ax)) - d = eigh_vals(a) - @test d ≡ Ones{real(elt)}(length(ax)) - end - @testset "left null" begin - ax = SU2(2) - a = δ(elt, (ax, ax)) - @test_broken left_null(a) - end - @testset "right null" begin - ax = SU2(2) - a = δ(elt, (ax, ax)) - @test_broken right_null(a) - end - end - @testset "ScaledDeltaMatrix factorizations (eltype=$elt)" for elt in ( - Float32, Float64, ComplexF32, ComplexF64 - ) - @testset "SVD" begin - for f in (svd_compact, svd_full) - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, elt) - a = scale * δ(elt, (ax, ax)) - u, s, v = f(a) - @test u * s * v ≡ a - @test u ≡ δ(elt, (ax, dual(ax))) - @test s ≡ abs(scale) * δ(real(elt), (ax, ax)) - @test v ≡ sign(scale) * δ(elt, (dual(ax), ax)) - end - end - @testset "SVD values" begin - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, elt) - a = scale * δ(elt, (ax, ax)) - s = svd_vals(a) - @test s ≡ abs(scale) * Ones(real(elt), length(ax)) - end - @testset "left orth" begin - for f in (left_orth, left_polar, qr_compact, qr_full) - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, elt) - a = scale * δ(elt, (ax, ax)) - q, r = f(a) - @test q * r ≡ a - @test q ≡ sign(scale) * δ(elt, (ax, dual(ax))) - @test r ≡ abs(scale) * δ(elt, (ax, ax)) - end - end - @testset "right orth" begin - for f in (lq_compact, lq_full, right_orth, right_polar) - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, elt) - a = scale * δ(elt, (ax, ax)) - l, q = f(a) - @test l * q ≡ a - @test l ≡ abs(scale) * δ(elt, (ax, ax)) - @test q ≡ sign(scale) * δ(elt, (dual(ax), ax)) - end - end - @testset "Eigendecomposition" begin - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, elt) - a = scale * δ(elt, (dual(ax), ax)) - d, v = eig_full(a) - @test a * v ≡ v * d - @test d ≡ scale * δ(complex(elt), (dual(ax), ax)) - @test v ≡ δ(complex(elt), (dual(ax), ax)) - end - @testset "Hermitian eigendecomposition" begin - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, real(elt)) - a = scale * δ(elt, (dual(ax), ax)) - d, v = eigh_full(a) - @test a * v ≡ v * d - @test d ≡ scale * δ(real(elt), (dual(ax), ax)) - @test v ≡ δ(elt, (dual(ax), ax)) - end - @testset "Eigenvalues" begin - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, elt) - a = scale * δ(elt, (dual(ax), ax)) - d = eig_vals(a) - @test d ≡ scale * Ones{complex(elt)}(length(ax)) - end - @testset "Hermitian eigenvalues" begin - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, real(elt)) - a = scale * δ(elt, (dual(ax), ax)) - d = eigh_vals(a) - @test d ≡ scale * Ones{real(elt)}(length(ax)) - end - @testset "left null" begin - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, real(elt)) - a = scale * δ(elt, (ax, ax)) - @test_broken left_null(a) - end - @testset "right null" begin - ax = SU2(2) - rng = StableRNG(1234) - scale = randn(rng, real(elt)) - a = scale * δ(elt, (ax, ax)) - @test_broken right_null(a) + @testset "ScaledDeltaMatrix factorizations (eltype=$elt)" for elt in ( + Float32, Float64, ComplexF32, ComplexF64, + ) + @testset "SVD" begin + for f in (svd_compact, svd_full) + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, elt) + a = scale * δ(elt, (ax, ax)) + u, s, v = f(a) + @test u * s * v ≡ a + @test u ≡ δ(elt, (ax, dual(ax))) + @test s ≡ abs(scale) * δ(real(elt), (ax, ax)) + @test v ≡ sign(scale) * δ(elt, (dual(ax), ax)) + end + end + @testset "SVD values" begin + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, elt) + a = scale * δ(elt, (ax, ax)) + s = svd_vals(a) + @test s ≡ abs(scale) * Ones(real(elt), length(ax)) + end + @testset "left orth" begin + for f in (left_orth, left_polar, qr_compact, qr_full) + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, elt) + a = scale * δ(elt, (ax, ax)) + q, r = f(a) + @test q * r ≡ a + @test q ≡ sign(scale) * δ(elt, (ax, dual(ax))) + @test r ≡ abs(scale) * δ(elt, (ax, ax)) + end + end + @testset "right orth" begin + for f in (lq_compact, lq_full, right_orth, right_polar) + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, elt) + a = scale * δ(elt, (ax, ax)) + l, q = f(a) + @test l * q ≡ a + @test l ≡ abs(scale) * δ(elt, (ax, ax)) + @test q ≡ sign(scale) * δ(elt, (dual(ax), ax)) + end + end + @testset "Eigendecomposition" begin + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, elt) + a = scale * δ(elt, (dual(ax), ax)) + d, v = eig_full(a) + @test a * v ≡ v * d + @test d ≡ scale * δ(complex(elt), (dual(ax), ax)) + @test v ≡ δ(complex(elt), (dual(ax), ax)) + end + @testset "Hermitian eigendecomposition" begin + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, real(elt)) + a = scale * δ(elt, (dual(ax), ax)) + d, v = eigh_full(a) + @test a * v ≡ v * d + @test d ≡ scale * δ(real(elt), (dual(ax), ax)) + @test v ≡ δ(elt, (dual(ax), ax)) + end + @testset "Eigenvalues" begin + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, elt) + a = scale * δ(elt, (dual(ax), ax)) + d = eig_vals(a) + @test d ≡ scale * Ones{complex(elt)}(length(ax)) + end + @testset "Hermitian eigenvalues" begin + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, real(elt)) + a = scale * δ(elt, (dual(ax), ax)) + d = eigh_vals(a) + @test d ≡ scale * Ones{real(elt)}(length(ax)) + end + @testset "left null" begin + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, real(elt)) + a = scale * δ(elt, (ax, ax)) + @test_broken left_null(a) + end + @testset "right null" begin + ax = SU2(2) + rng = StableRNG(1234) + scale = randn(rng, real(elt)) + a = scale * δ(elt, (ax, ax)) + @test_broken right_null(a) + end end - end end