diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 4c49a86..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options -style = "blue" -indent = 2 diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 456fa05..0614de9 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -2,7 +2,7 @@ name: "CompatHelper" on: schedule: - - cron: 0 0 * * * + - cron: '0 0 * * *' workflow_dispatch: permissions: contents: write diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index 3f78afc..1525861 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,11 +1,14 @@ name: "Format Check" on: - push: - branches: - - 'main' - tags: '*' - pull_request: + pull_request_target: + paths: ['**/*.jl'] + types: [opened, synchronize, reopened, ready_for_review] + +permissions: + contents: read + actions: write + pull-requests: write jobs: format-check: diff --git a/.gitignore b/.gitignore index 10593a9..7085ca8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,10 @@ .vscode/ Manifest.toml benchmark/*.json +dev/ +docs/LocalPreferences.toml docs/Manifest.toml docs/build/ docs/src/index.md +examples/LocalPreferences.toml +test/LocalPreferences.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88bc8b4..3fc4743 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ ci: - skip: [julia-formatter] + skip: [runic] repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -11,7 +11,7 @@ repos: - id: end-of-file-fixer exclude_types: [markdown] # incompatible with Literate.jl -- repo: "https://github.com/domluna/JuliaFormatter.jl" - rev: v2.1.6 +- repo: https://github.com/fredrikekre/runic-pre-commit + rev: v2.0.1 hooks: - - id: "julia-formatter" + - id: runic diff --git a/Project.toml b/Project.toml index ccf6d09..dae7972 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.4.3" +version = "0.4.4" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -13,20 +13,20 @@ TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" +[weakdeps] +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" + +[extensions] +TensorAlgebraTensorOperationsExt = "TensorOperations" + [compat] ArrayLayouts = "1.10.4" BlockArrays = "1.7.2" EllipsisNotation = "1.8.0" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.2, 0.3, 0.4" -TensorProducts = "0.1.5" TensorOperations = "5" +TensorProducts = "0.1.5" TupleTools = "1.6.0" TypeParameterAccessors = "0.2.1, 0.3, 0.4" julia = "1.10" - -[weakdeps] -TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" - -[extensions] -TensorAlgebraTensorOperationsExt = "TensorOperations" diff --git a/README.md b/README.md index 5d687e5..830f2b0 100644 --- a/README.md +++ b/README.md @@ -25,11 +25,11 @@ This step is only required once. ```julia julia> using Pkg: Pkg -julia> Pkg.Registry.add(url="https://github.com/ITensor/ITensorRegistry") +julia> Pkg.Registry.add(url = "https://github.com/ITensor/ITensorRegistry") ``` or: ```julia -julia> Pkg.Registry.add(url="git@github.com:ITensor/ITensorRegistry.git") +julia> Pkg.Registry.add(url = "git@github.com:ITensor/ITensorRegistry.git") ``` if you want to use SSH credentials, which can make it so you don't have to enter your Github ursername and password when registering packages. diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 2fb47cd..faa46b3 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -15,12 +15,12 @@ Ts = (Float64, ComplexF64) algs = (TensorAlgebra.Matricize(),) for alg in algs - alg_suite = contraction_suite[alg] = BenchmarkGroup() - for T in Ts - alg_suite[T] = BenchmarkGroup() + alg_suite = contraction_suite[alg] = BenchmarkGroup() + for T in Ts + alg_suite[T] = BenchmarkGroup() - for (i, line) in enumerate(eachline(CONTRACTIONS_PATH)) - alg_suite[T][i] = generate_contract_benchmark(line; T, alg) + for (i, line) in enumerate(eachline(CONTRACTIONS_PATH)) + alg_suite[T][i] = generate_contract_benchmark(line; T, alg) + end end - end end diff --git a/benchmark/contractions.jl b/benchmark/contractions.jl index 67d0f4f..7304dc1 100644 --- a/benchmark/contractions.jl +++ b/benchmark/contractions.jl @@ -1,75 +1,75 @@ function extract_contract_labels(contraction::AbstractString) - symbolsC = match(r"C\[([^\]]*)\]", contraction) - labelsC = split(symbolsC.captures[1], ","; keepempty=false) - symbolsA = match(r"A\[([^\]]*)\]", contraction) - labelsA = split(symbolsA.captures[1], ","; keepempty=false) - symbolsB = match(r"B\[([^\]]*)\]", contraction) - labelsB = split(symbolsB.captures[1], ","; keepempty=false) - return labelsC, labelsA, labelsB + symbolsC = match(r"C\[([^\]]*)\]", contraction) + labelsC = split(symbolsC.captures[1], ","; keepempty = false) + symbolsA = match(r"A\[([^\]]*)\]", contraction) + labelsA = split(symbolsA.captures[1], ","; keepempty = false) + symbolsB = match(r"B\[([^\]]*)\]", contraction) + labelsB = split(symbolsB.captures[1], ","; keepempty = false) + return labelsC, labelsA, labelsB end function generate_contract_benchmark( - line::AbstractString; elt=Float64, alg=default_contract_alg(), do_alpha=true, do_beta=true -) - line_split = split(line, " & ") - @assert length(line_split) == 2 "Invalid line format:\n$line" - contraction, sizes = line_split + line::AbstractString; elt = Float64, alg = default_contract_alg(), do_alpha = true, do_beta = true + ) + line_split = split(line, " & ") + @assert length(line_split) == 2 "Invalid line format:\n$line" + contraction, sizes = line_split - # extract labels - labelsC, labelsA, labelsB = map(Tuple, extract_contract_labels(contraction)) - # pA, pB, pC = TensorOperations.contract_indices( - # tuple(labelsA...), tuple(labelsB...), tuple(labelsC...) - # ) + # extract labels + labelsC, labelsA, labelsB = map(Tuple, extract_contract_labels(contraction)) + # pA, pB, pC = TensorOperations.contract_indices( + # tuple(labelsA...), tuple(labelsB...), tuple(labelsC...) + # ) - # extract sizes - subsizes = Dict{String,Int}() - for (label, sz) in split.(split(sizes, "; "; keepempty=false), Ref(":")) - subsizes[label] = parse(Int, sz) - end - szA = getindex.(Ref(subsizes), labelsA) - szB = getindex.(Ref(subsizes), labelsB) - szC = getindex.(Ref(subsizes), labelsC) - setup_tensors() = (rand(elt, szA...), rand(elt, szB...), rand(elt, szC...)) + # extract sizes + subsizes = Dict{String, Int}() + for (label, sz) in split.(split(sizes, "; "; keepempty = false), Ref(":")) + subsizes[label] = parse(Int, sz) + end + szA = getindex.(Ref(subsizes), labelsA) + szB = getindex.(Ref(subsizes), labelsB) + szC = getindex.(Ref(subsizes), labelsC) + setup_tensors() = (rand(elt, szA...), rand(elt, szB...), rand(elt, szC...)) - if do_alpha && do_beta - α, β = rand(elt, 2) - return @benchmarkable( - contract!($alg, C, $labelsC, A, $labelsA, B, $labelsB, $α, $β), - setup = ((A, B, C) = $setup_tensors()), - evals = 1 - ) - elseif do_alpha - α = rand(elt) - return @benchmarkable( - contract!($alg, C, $labelsC, A, $labelsA, B, $labelsB, $α), - setup = ((A, B, C) = $setup_tensors()), - evals = 1 - ) - elseif do_beta - β = rand(elt) - return @benchmarkable( - contract!($alg, C, $labelsC, A, $labelsA, B, $labelsB, true, $β), - setup = ((A, B, C) = $setup_tensors()), - evals = 1 - ) - else - return @benchmarkable( - contract!($alg, C, $labelsC, A, $labelsA, B, $labelsB), - setup = ((A, B, C) = $setup_tensors()), - evals = 1 - ) - end + if do_alpha && do_beta + α, β = rand(elt, 2) + return @benchmarkable( + contract!($alg, C, $labelsC, A, $labelsA, B, $labelsB, $α, $β), + setup = ((A, B, C) = $setup_tensors()), + evals = 1 + ) + elseif do_alpha + α = rand(elt) + return @benchmarkable( + contract!($alg, C, $labelsC, A, $labelsA, B, $labelsB, $α), + setup = ((A, B, C) = $setup_tensors()), + evals = 1 + ) + elseif do_beta + β = rand(elt) + return @benchmarkable( + contract!($alg, C, $labelsC, A, $labelsA, B, $labelsB, true, $β), + setup = ((A, B, C) = $setup_tensors()), + evals = 1 + ) + else + return @benchmarkable( + contract!($alg, C, $labelsC, A, $labelsA, B, $labelsB), + setup = ((A, B, C) = $setup_tensors()), + evals = 1 + ) + end end function compute_contract_ops(line::AbstractString) - line_split = split(line, " & ") - @assert length(line_split) == 2 "Invalid line format:\n$line" - _, sizes = line_split + line_split = split(line, " & ") + @assert length(line_split) == 2 "Invalid line format:\n$line" + _, sizes = line_split - # extract sizes - subsizes = Dict{String,Int}() - for (label, sz) in split.(split(sizes, "; "; keepempty=false), Ref("=")) - subsizes[label] = parse(Int, sz) - end - return prod(collect(values(subsizes))) + # extract sizes + subsizes = Dict{String, Int}() + for (label, sz) in split.(split(sizes, "; "; keepempty = false), Ref("=")) + subsizes[label] = parse(Int, sz) + end + return prod(collect(values(subsizes))) end diff --git a/docs/make.jl b/docs/make.jl index 6e1aaf2..5f2f650 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,22 +1,24 @@ using TensorAlgebra: TensorAlgebra using Documenter: Documenter, DocMeta, deploydocs, makedocs -DocMeta.setdocmeta!(TensorAlgebra, :DocTestSetup, :(using TensorAlgebra); recursive=true) +DocMeta.setdocmeta!( + TensorAlgebra, :DocTestSetup, :(using TensorAlgebra); recursive = true +) include("make_index.jl") makedocs(; - modules=[TensorAlgebra], - authors="ITensor developers and contributors", - sitename="TensorAlgebra.jl", - format=Documenter.HTML(; - canonical="https://itensor.github.io/TensorAlgebra.jl", - edit_link="main", - assets=["assets/favicon.ico", "assets/extras.css"], - ), - pages=["Home" => "index.md", "Reference" => "reference.md"], + modules = [TensorAlgebra], + authors = "ITensor developers and contributors", + sitename = "TensorAlgebra.jl", + format = Documenter.HTML(; + canonical = "https://itensor.github.io/TensorAlgebra.jl", + edit_link = "main", + assets = ["assets/favicon.ico", "assets/extras.css"], + ), + pages = ["Home" => "index.md", "Reference" => "reference.md"], ) deploydocs(; - repo="github.com/ITensor/TensorAlgebra.jl", devbranch="main", push_preview=true + repo = "github.com/ITensor/TensorAlgebra.jl", devbranch = "main", push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index 2832b29..ca4e0fd 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -2,20 +2,20 @@ using Literate: Literate using TensorAlgebra: TensorAlgebra function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ ```@raw html Flatiron Center for Computational Quantum Physics logo. Flatiron Center for Computational Quantum Physics logo. ``` """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(TensorAlgebra), "examples", "README.jl"), - joinpath(pkgdir(TensorAlgebra), "docs", "src"); - flavor=Literate.DocumenterFlavor(), - name="index", - postprocess=ccq_logo, + joinpath(pkgdir(TensorAlgebra), "examples", "README.jl"), + joinpath(pkgdir(TensorAlgebra), "docs", "src"); + flavor = Literate.DocumenterFlavor(), + name = "index", + postprocess = ccq_logo, ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index dde3d91..25e7379 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -2,20 +2,20 @@ using Literate: Literate using TensorAlgebra: TensorAlgebra function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ Flatiron Center for Computational Quantum Physics logo. """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(TensorAlgebra), "examples", "README.jl"), - joinpath(pkgdir(TensorAlgebra)); - flavor=Literate.CommonMarkFlavor(), - name="README", - postprocess=ccq_logo, + joinpath(pkgdir(TensorAlgebra), "examples", "README.jl"), + joinpath(pkgdir(TensorAlgebra)); + flavor = Literate.CommonMarkFlavor(), + name = "README", + postprocess = ccq_logo, ) diff --git a/examples/README.jl b/examples/README.jl index 6ef33c9..c03df58 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -1,5 +1,5 @@ # # TensorAlgebra.jl -# +# # [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://itensor.github.io/TensorAlgebra.jl/stable/) # [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://itensor.github.io/TensorAlgebra.jl/dev/) # [![Build Status](https://github.com/ITensor/TensorAlgebra.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/ITensor/TensorAlgebra.jl/actions/workflows/Tests.yml?query=branch%3Amain) @@ -22,13 +22,13 @@ ```julia julia> using Pkg: Pkg -julia> Pkg.Registry.add(url="https://github.com/ITensor/ITensorRegistry") +julia> Pkg.Registry.add(url = "https://github.com/ITensor/ITensorRegistry") ``` =# # or: #= ```julia -julia> Pkg.Registry.add(url="git@github.com:ITensor/ITensorRegistry.git") +julia> Pkg.Registry.add(url = "git@github.com:ITensor/ITensorRegistry.git") ``` =# # if you want to use SSH credentials, which can make it so you don't have to enter your Github ursername and password when registering packages. diff --git a/ext/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt.jl index 2edf43a..cc05324 100644 --- a/ext/TensorAlgebraTensorOperationsExt.jl +++ b/ext/TensorAlgebraTensorOperationsExt.jl @@ -9,8 +9,8 @@ using TensorOperations: TensorOperations, AbstractBackend, DefaultBackend, Index Wrapper type for making a TensorOperations backend work as a TensorAlgebra algorithm. """ -struct TensorOperationsAlgorithm{B<:AbstractBackend} <: Algorithm - backend::B +struct TensorOperationsAlgorithm{B <: AbstractBackend} <: Algorithm + backend::B end TensorAlgebra.Algorithm(backend::AbstractBackend) = TensorOperationsAlgorithm(backend) @@ -18,11 +18,11 @@ TensorAlgebra.Algorithm(backend::AbstractBackend) = TensorOperationsAlgorithm(ba trivtuple(n) = ntuple(identity, n) function _index2tuple(p::BlockedPermutation{2}) - N₁, N₂ = blocklengths(p) - return ( - TupleTools.getindices(Tuple(p), trivtuple(N₁)), - TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)), - ) + N₁, N₂ = blocklengths(p) + return ( + TupleTools.getindices(Tuple(p), trivtuple(N₁)), + TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)), + ) end _blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...) @@ -32,121 +32,121 @@ _blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...) # not in-place function TensorAlgebra.contract( - algorithm::TensorOperationsAlgorithm, - bipermAB::BlockedPermutation, - A::AbstractArray, - bipermA::BlockedPermutation, - B::AbstractArray, - bipermB::BlockedPermutation, - α::Number, -) - pA = _index2tuple(bipermA) - pB = _index2tuple(bipermB) - pAB = _index2tuple(bipermAB) - return TensorOperations.tensorcontract( - A, pA, false, B, pB, false, pAB, α, algorithm.backend - ) + algorithm::TensorOperationsAlgorithm, + bipermAB::BlockedPermutation, + A::AbstractArray, + bipermA::BlockedPermutation, + B::AbstractArray, + bipermB::BlockedPermutation, + α::Number, + ) + pA = _index2tuple(bipermA) + pB = _index2tuple(bipermB) + pAB = _index2tuple(bipermAB) + return TensorOperations.tensorcontract( + A, pA, false, B, pB, false, pAB, α, algorithm.backend + ) end function TensorAlgebra.contract( - algorithm::TensorOperationsAlgorithm, - labelsC, - A::AbstractArray, - labelsA, - B::AbstractArray, - labelsB, - α::Number, -) - pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC) - return tensorcontract(A, pA, false, B, pB, false, pAB, α, algorithm.backend) + algorithm::TensorOperationsAlgorithm, + labelsC, + A::AbstractArray, + labelsA, + B::AbstractArray, + labelsB, + α::Number, + ) + pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC) + return tensorcontract(A, pA, false, B, pB, false, pAB, α, algorithm.backend) end # in-place function TensorAlgebra.contractadd!( - algorithm::TensorOperationsAlgorithm, - C::AbstractArray, - bipermAB::BlockedPermutation, - A::AbstractArray, - bipermA::BlockedPermutation, - B::AbstractArray, - bipermB::BlockedPermutation, - α::Number, - β::Number, -) - pA = _index2tuple(bipermA) - pB = _index2tuple(bipermB) - pAB = _index2tuple(bipermAB) - return TensorOperations.tensorcontract!( - C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend - ) + algorithm::TensorOperationsAlgorithm, + C::AbstractArray, + bipermAB::BlockedPermutation, + A::AbstractArray, + bipermA::BlockedPermutation, + B::AbstractArray, + bipermB::BlockedPermutation, + α::Number, + β::Number, + ) + pA = _index2tuple(bipermA) + pB = _index2tuple(bipermB) + pAB = _index2tuple(bipermAB) + return TensorOperations.tensorcontract!( + C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend + ) end function TensorAlgebra.contractadd!( - algorithm::TensorOperationsAlgorithm, - C::AbstractArray, - labelsC, - A::AbstractArray, - labelsA, - B::AbstractArray, - labelsB, - α::Number, - β::Number, -) - pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC) - return TensorOperations.tensorcontract!( - C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend - ) + algorithm::TensorOperationsAlgorithm, + C::AbstractArray, + labelsC, + A::AbstractArray, + labelsA, + B::AbstractArray, + labelsB, + α::Number, + β::Number, + ) + pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC) + return TensorOperations.tensorcontract!( + C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend + ) end # Using TensorAlgebra implementations as TensorOperations backends # ---------------------------------------------------------------- function TensorOperations.tensorcontract!( - C::AbstractArray, - A::AbstractArray, - pA::Index2Tuple, - conjA::Bool, - B::AbstractArray, - pB::Index2Tuple, - conjB::Bool, - pAB::Index2Tuple, - α::Number, - β::Number, - backend::Algorithm, - allocator, -) - bipermA = _blockedpermutation(pA) - bipermB = _blockedpermutation(pB) - bipermAB = _blockedpermutation(pAB) - A′ = conjA ? conj(A) : A - B′ = conjB ? conj(B) : B - return TensorAlgebra.contractadd!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β) + C::AbstractArray, + A::AbstractArray, + pA::Index2Tuple, + conjA::Bool, + B::AbstractArray, + pB::Index2Tuple, + conjB::Bool, + pAB::Index2Tuple, + α::Number, + β::Number, + backend::Algorithm, + allocator, + ) + bipermA = _blockedpermutation(pA) + bipermB = _blockedpermutation(pB) + bipermAB = _blockedpermutation(pAB) + A′ = conjA ? conj(A) : A + B′ = conjB ? conj(B) : B + return TensorAlgebra.contractadd!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β) end # For now no trace/add is supported, so simply reselect default backend from TensorOperations function TensorOperations.tensortrace!( - C::AbstractArray, - A::AbstractArray, - p::Index2Tuple, - q::Index2Tuple, - conjA::Bool, - α::Number, - β::Number, - ::Algorithm, - allocator, -) - return TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, DefaultBackend(), allocator) + C::AbstractArray, + A::AbstractArray, + p::Index2Tuple, + q::Index2Tuple, + conjA::Bool, + α::Number, + β::Number, + ::Algorithm, + allocator, + ) + return TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, DefaultBackend(), allocator) end function TensorOperations.tensoradd!( - C::AbstractArray, - A::AbstractArray, - pA::Index2Tuple, - conjA::Bool, - α::Number, - β::Number, - ::Algorithm, - allocator, -) - return TensorOperations.tensoradd!(C, A, pA, conjA, α, β, DefaultBackend(), allocator) + C::AbstractArray, + A::AbstractArray, + pA::Index2Tuple, + conjA::Bool, + α::Number, + β::Number, + ::Algorithm, + allocator, + ) + return TensorOperations.tensoradd!(C, A, pA, conjA, α, β, DefaultBackend(), allocator) end end diff --git a/src/BaseExtensions/indexin.jl b/src/BaseExtensions/indexin.jl index c302823..b25e225 100644 --- a/src/BaseExtensions/indexin.jl +++ b/src/BaseExtensions/indexin.jl @@ -1,5 +1,5 @@ # `Base.indexin` doesn't handle tuples indexin(x, y::AbstractArray) = Base.indexin(x, y) indexin(x, y) = Base.indexin(x, collect(y)) -indexin(x::Tuple, y::AbstractArray) = Tuple{Vararg{Any,length(x)}}(Base.indexin(x, y)) -indexin(x::Tuple, y) = Tuple{Vararg{Any,length(x)}}(Base.indexin(x, collect(y))) +indexin(x::Tuple, y::AbstractArray) = Tuple{Vararg{Any, length(x)}}(Base.indexin(x, y)) +indexin(x::Tuple, y) = Tuple{Vararg{Any, length(x)}}(Base.indexin(x, collect(y))) diff --git a/src/BaseExtensions/permutedims.jl b/src/BaseExtensions/permutedims.jl index 19f8fd7..0f7ea38 100644 --- a/src/BaseExtensions/permutedims.jl +++ b/src/BaseExtensions/permutedims.jl @@ -2,20 +2,20 @@ # Fixed by https://github.com/JuliaLang/julia/pull/52623. # TODO remove once support for Julia 1.10 is dropped function _permutedims!( - a_dest::AbstractArray{<:Any,N}, a_src::AbstractArray{<:Any,N}, perm::Tuple{Vararg{Int,N}} -) where {N} - permutedims!(a_dest, a_src, perm) - return a_dest + a_dest::AbstractArray{<:Any, N}, a_src::AbstractArray{<:Any, N}, perm::Tuple{Vararg{Int, N}} + ) where {N} + permutedims!(a_dest, a_src, perm) + return a_dest end function _permutedims!( - a_dest::AbstractArray{<:Any,0}, a_src::AbstractArray{<:Any,0}, perm::Tuple{} -) - a_dest[] = a_src[] - return a_dest + a_dest::AbstractArray{<:Any, 0}, a_src::AbstractArray{<:Any, 0}, perm::Tuple{} + ) + a_dest[] = a_src[] + return a_dest end -function _permutedims(a::AbstractArray{<:Any,N}, perm::Tuple{Vararg{Int,N}}) where {N} - return permutedims(a, perm) +function _permutedims(a::AbstractArray{<:Any, N}, perm::Tuple{Vararg{Int, N}}) where {N} + return permutedims(a, perm) end -function _permutedims(a::AbstractArray{<:Any,0}, perm::Tuple{}) - return copy(a) +function _permutedims(a::AbstractArray{<:Any, 0}, perm::Tuple{}) + return copy(a) end diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 5eec67f..d3e3658 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -1,151 +1,151 @@ module MatrixAlgebra export eigen, - eigen!, - eigvals, - eigvals!, - factorize, - factorize!, - lq, - lq!, - orth, - orth!, - polar, - polar!, - qr, - qr!, - svd, - svd!, - svdvals, - svdvals! + eigen!, + eigvals, + eigvals!, + factorize, + factorize!, + lq, + lq!, + orth, + orth!, + polar, + polar!, + qr, + qr!, + svd, + svd!, + svdvals, + svdvals! using LinearAlgebra: LinearAlgebra, norm using MatrixAlgebraKit for (f, f_full, f_compact) in ( - (:qr, :qr_full, :qr_compact), - (:qr!, :qr_full!, :qr_compact!), - (:lq, :lq_full, :lq_compact), - (:lq!, :lq_full!, :lq_compact!), -) - @eval begin - function $f(A::AbstractMatrix; full::Bool=false, kwargs...) - f = full ? $f_full : $f_compact - return f(A; kwargs...) + (:qr, :qr_full, :qr_compact), + (:qr!, :qr_full!, :qr_compact!), + (:lq, :lq_full, :lq_compact), + (:lq!, :lq_full!, :lq_compact!), + ) + @eval begin + function $f(A::AbstractMatrix; full::Bool = false, kwargs...) + f = full ? $f_full : $f_compact + return f(A; kwargs...) + end end - end end for (eigen, eigh_full, eig_full, eigh_trunc, eig_trunc) in ( - (:eigen, :eigh_full, :eig_full, :eigh_trunc, :eig_trunc), - (:eigen!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!), -) - @eval begin - function $eigen(A::AbstractMatrix; trunc=nothing, ishermitian=nothing, kwargs...) - ishermitian = @something ishermitian LinearAlgebra.ishermitian(A) - return if !isnothing(trunc) - if ishermitian - $eigh_trunc(A; trunc, kwargs...) - else - $eig_trunc(A; trunc, kwargs...) - end - else - if ishermitian - $eigh_full(A; kwargs...) - else - $eig_full(A; kwargs...) + (:eigen, :eigh_full, :eig_full, :eigh_trunc, :eig_trunc), + (:eigen!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!), + ) + @eval begin + function $eigen(A::AbstractMatrix; trunc = nothing, ishermitian = nothing, kwargs...) + ishermitian = @something ishermitian LinearAlgebra.ishermitian(A) + return if !isnothing(trunc) + if ishermitian + $eigh_trunc(A; trunc, kwargs...) + else + $eig_trunc(A; trunc, kwargs...) + end + else + if ishermitian + $eigh_full(A; kwargs...) + else + $eig_full(A; kwargs...) + end + end end - end end - end end for (eigvals, eigh_vals, eig_vals) in ((:eigvals, :eigh_vals, :eig_vals), (:eigvals!, :eigh_vals!, :eig_vals!)) - @eval begin - function $eigvals(A::AbstractMatrix; ishermitian=nothing, kwargs...) - ishermitian = @something ishermitian LinearAlgebra.ishermitian(A) - f = (ishermitian ? $eigh_vals : $eig_vals) - return f(A; kwargs...) + @eval begin + function $eigvals(A::AbstractMatrix; ishermitian = nothing, kwargs...) + ishermitian = @something ishermitian LinearAlgebra.ishermitian(A) + f = (ishermitian ? $eigh_vals : $eig_vals) + return f(A; kwargs...) + end end - end end for (svd, svd_trunc, svd_full, svd_compact) in ( - (:svd, :svd_trunc, :svd_full, :svd_compact), - (:svd!, :svd_trunc!, :svd_full!, :svd_compact!), -) - @eval begin - function $svd(A::AbstractMatrix; full::Bool=false, trunc=nothing, kwargs...) - return if !isnothing(trunc) - @assert !full "Specified both full and truncation, currently not supported" - $svd_trunc(A; trunc, kwargs...) - else - (full ? $svd_full : $svd_compact)(A; kwargs...) - end + (:svd, :svd_trunc, :svd_full, :svd_compact), + (:svd!, :svd_trunc!, :svd_full!, :svd_compact!), + ) + @eval begin + function $svd(A::AbstractMatrix; full::Bool = false, trunc = nothing, kwargs...) + return if !isnothing(trunc) + @assert !full "Specified both full and truncation, currently not supported" + $svd_trunc(A; trunc, kwargs...) + else + (full ? $svd_full : $svd_compact)(A; kwargs...) + end + end end - end end for (svdvals, svd_vals) in ((:svdvals, :svd_vals), (:svdvals!, :svd_vals!)) - @eval begin - function $svdvals(A::AbstractMatrix; ishermitian=nothing, kwargs...) - return $svd_vals(A; kwargs...) + @eval begin + function $svdvals(A::AbstractMatrix; ishermitian = nothing, kwargs...) + return $svd_vals(A; kwargs...) + end end - end end for (polar, left_polar, right_polar) in ((:polar, :left_polar, :right_polar), (:polar!, :left_polar!, :right_polar!)) - @eval begin - function $polar(A::AbstractMatrix; side=:left, kwargs...) - f = if side == :left - $left_polar - elseif side == :right - $right_polar - else - throw(ArgumentError("`side=$side` not supported.")) - end - return f(A; kwargs...) + @eval begin + function $polar(A::AbstractMatrix; side = :left, kwargs...) + f = if side == :left + $left_polar + elseif side == :right + $right_polar + else + throw(ArgumentError("`side=$side` not supported.")) + end + return f(A; kwargs...) + end end - end end for (orth, left_orth, right_orth) in ((:orth, :left_orth, :right_orth), (:orth!, :left_orth!, :right_orth!)) - @eval begin - function $orth(A::AbstractMatrix; side=:left, kwargs...) - f = if side == :left - $left_orth - elseif side == :right - $right_orth - else - throw(ArgumentError("`side=$side` not supported.")) - end - return f(A; kwargs...) + @eval begin + function $orth(A::AbstractMatrix; side = :left, kwargs...) + f = if side == :left + $left_orth + elseif side == :right + $right_orth + else + throw(ArgumentError("`side=$side` not supported.")) + end + return f(A; kwargs...) + end end - end end for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, :orth!)) - @eval begin - function $factorize(A::AbstractMatrix; orth=:left, kwargs...) - f = if orth in (:left, :right) - $orth_f - else - throw(ArgumentError("`orth=$orth` not supported.")) - end - return f(A; side=orth, kwargs...) + @eval begin + function $factorize(A::AbstractMatrix; orth = :left, kwargs...) + f = if orth in (:left, :right) + $orth_f + else + throw(ArgumentError("`orth=$orth` not supported.")) + end + return f(A; side = orth, kwargs...) + end end - end end using MatrixAlgebraKit: MatrixAlgebraKit, TruncationStrategy -struct TruncationDegenerate{Strategy<:TruncationStrategy,T<:Real} <: TruncationStrategy - strategy::Strategy - atol::T - rtol::T +struct TruncationDegenerate{Strategy <: TruncationStrategy, T <: Real} <: TruncationStrategy + strategy::Strategy + atol::T + rtol::T end """ @@ -165,33 +165,33 @@ also only truncates for now, so may not respect if a minimum dimension was requested in the strategy being wrapped. These restrictions may be lifted in the future or provided through a different truncation strategy. """ -function truncdegen(strategy::TruncationStrategy; atol::Real=0, rtol::Real=0) - return TruncationDegenerate(strategy, promote(atol, rtol)...) +function truncdegen(strategy::TruncationStrategy; atol::Real = 0, rtol::Real = 0) + return TruncationDegenerate(strategy, promote(atol, rtol)...) end using MatrixAlgebraKit: findtruncated function MatrixAlgebraKit.findtruncated( - values::AbstractVector, strategy::TruncationDegenerate -) - Base.require_one_based_indexing(values) - issorted(values; rev=true) || throw(ArgumentError("Values must be reverse sorted.")) - indices_collection = findtruncated(values, strategy.strategy) - indices = Base.OneTo(maximum(indices_collection)) - indices_collection == indices || - throw(ArgumentError("Truncation must be a contiguous range.")) - if length(indices_collection) == length(values) - # No truncation occurred. - return indices - end - # The largest truncated value. - truncval = values[last(indices) + 1] - # Tolerance of determining if a value is degenerate. - atol = max(strategy.atol, strategy.rtol * abs(truncval)) - for rank in reverse(indices) - ≈(values[rank], truncval; atol, rtol=0) || return Base.OneTo(rank) - end - return Base.OneTo(0) + values::AbstractVector, strategy::TruncationDegenerate + ) + Base.require_one_based_indexing(values) + issorted(values; rev = true) || throw(ArgumentError("Values must be reverse sorted.")) + indices_collection = findtruncated(values, strategy.strategy) + indices = Base.OneTo(maximum(indices_collection)) + indices_collection == indices || + throw(ArgumentError("Truncation must be a contiguous range.")) + if length(indices_collection) == length(values) + # No truncation occurred. + return indices + end + # The largest truncated value. + truncval = values[last(indices) + 1] + # Tolerance of determining if a value is degenerate. + atol = max(strategy.atol, strategy.rtol * abs(truncval)) + for rank in reverse(indices) + ≈(values[rank], truncval; atol, rtol = 0) || return Base.OneTo(rank) + end + return Base.OneTo(0) end end diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 125f9b7..b848c1d 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,22 +1,22 @@ module TensorAlgebra export contract, - contract!, - eigen, - eigvals, - factorize, - left_null, - left_orth, - left_polar, - lq, - qr, - right_null, - right_orth, - right_polar, - orth, - polar, - svd, - svdvals + contract!, + eigen, + eigvals, + factorize, + left_null, + left_orth, + left_polar, + lq, + qr, + right_null, + right_orth, + right_polar, + orth, + polar, + svd, + svdvals include("MatrixAlgebra.jl") include("blockedtuple.jl") diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index c383587..03a1a69 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -1,18 +1,18 @@ using BlockArrays: - BlockArrays, Block, blockfirsts, blocklasts, blocklength, blocklengths, blocks + BlockArrays, Block, blockfirsts, blocklasts, blocklength, blocklengths, blocks using EllipsisNotation: Ellipsis, var".." using TupleTools: TupleTools trivialperm(len) = ntuple(identity, len) function istrivialperm(t::Tuple) - return t == trivialperm(length(t)) + return t == trivialperm(length(t)) end value(::Val{N}) where {N} = N _flatten_tuples(t::Tuple) = t function _flatten_tuples(t1::Tuple, t2::Tuple, trest::Tuple...) - return _flatten_tuples((t1..., t2...), trest...) + return _flatten_tuples((t1..., t2...), trest...) end _flatten_tuples() = () flatten_tuples(ts::Tuple) = _flatten_tuples(ts...) @@ -33,16 +33,16 @@ widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple # TODO: Optimize with StaticNumbers.jl or generated functions, see: # https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567 function blockedperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}}) - return blockedperm(BlockedTuple(perm, blocklengths)) + return blockedperm(BlockedTuple(perm, blocklengths)) end function blockedperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val) - return blockedperm(BlockedTuple(perm, BlockLengths)) + return blockedperm(BlockedTuple(perm, BlockLengths)) end function Base.invperm(bp::AbstractBlockPermutation) - # use Val to preserve compile time info - return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp))) + # use Val to preserve compile time info + return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp))) end # interface @@ -51,7 +51,7 @@ end # bipartitioned permutation. # Like `Base.permute!` block out-of-place and blocked. function blockpermute(v, blockedperm::AbstractBlockPermutation) - return tuplemortar(map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))) + return tuplemortar(map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))) end Base.getindex(v, perm::AbstractBlockPermutation) = blockpermute(v, perm) @@ -61,43 +61,43 @@ Base.getindex(v, perm::AbstractBlockPermutation) = blockpermute(v, perm) # function blockedperm(bt::AbstractBlockTuple) - return permmortar(blocks(bt)) + return permmortar(blocks(bt)) end # blockedpermvcat((4, 3), (2, 1)) function blockedpermvcat( - permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing -) - return blockedpermvcat(length, permblocks...) + permblocks::Tuple{Vararg{Int}}...; length::Union{Val, Nothing} = nothing + ) + return blockedpermvcat(length, permblocks...) end function blockedpermvcat(::Nothing, permblocks::Tuple{Vararg{Int}}...) - return blockedpermvcat(Val(sum(length, permblocks; init=zero(Bool))), permblocks...) + return blockedpermvcat(Val(sum(length, permblocks; init = zero(Bool))), permblocks...) end # blockedpermvcat((3, 2), 1) == blockedpermvcat((3, 2), (1,)) -function blockedpermvcat(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...) - return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...) +function blockedpermvcat(permblocks::Union{Tuple{Vararg{Int}}, Int}...; kwargs...) + return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...) end function blockedpermvcat( - permblocks::Union{Tuple{Vararg{Int}},Tuple{Ellipsis},Int,Ellipsis}...; kwargs... -) - return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...) + permblocks::Union{Tuple{Vararg{Int}}, Tuple{Ellipsis}, Int, Ellipsis}...; kwargs... + ) + return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...) end function blockedpermvcat(len::Val, permblocks::Tuple{Vararg{Int}}...) - value(len) != sum(length.(permblocks); init=0) && - throw(ArgumentError("Invalid total length")) - return permmortar(Tuple(permblocks)) + value(len) != sum(length.(permblocks); init = 0) && + throw(ArgumentError("Invalid total length")) + return permmortar(Tuple(permblocks)) end function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}}) - return maximum(specified_perm) + return maximum(specified_perm) end function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}}) - return value(vallength) + return value(vallength) end # blockedpermvcat((4, 3), .., 1) == blockedpermvcat((4, 3), (2,), (1,)) @@ -105,34 +105,34 @@ end # blockedpermvcat((4, 3), (..,), 1) == blockedpermvcat((4, 3), (2,), (1,)) # blockedpermvcat((4, 3), (..,), 1; length=Val(5)) == blockedpermvcat((4, 3), (2, 5), (1,)) function blockedpermvcat( - permblocks::Union{Tuple{Vararg{Int}},Ellipsis,Tuple{Ellipsis}}...; - length::Union{Val,Nothing}=nothing, -) - # Check there is only one `Ellipsis`. - @assert isone(count(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks)) - specified_permblocks = filter(x -> !(x isa Union{Ellipsis,Tuple{Ellipsis}}), permblocks) - unspecified_dim = findfirst(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks) - specified_perm = flatten_tuples(specified_permblocks) - len = _blockedperm_length(length, specified_perm) - unspecified_dims_vec = setdiff(Base.OneTo(len), specified_perm) - ndims_unspecified = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible - insert = unspecified_dims( - permblocks[unspecified_dim], unspecified_dims_vec, ndims_unspecified - ) - permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert) - return blockedpermvcat(permblocks_specified...) + permblocks::Union{Tuple{Vararg{Int}}, Ellipsis, Tuple{Ellipsis}}...; + length::Union{Val, Nothing} = nothing, + ) + # Check there is only one `Ellipsis`. + @assert isone(count(x -> x isa Union{Ellipsis, Tuple{Ellipsis}}, permblocks)) + specified_permblocks = filter(x -> !(x isa Union{Ellipsis, Tuple{Ellipsis}}), permblocks) + unspecified_dim = findfirst(x -> x isa Union{Ellipsis, Tuple{Ellipsis}}, permblocks) + specified_perm = flatten_tuples(specified_permblocks) + len = _blockedperm_length(length, specified_perm) + unspecified_dims_vec = setdiff(Base.OneTo(len), specified_perm) + ndims_unspecified = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible + insert = unspecified_dims( + permblocks[unspecified_dim], unspecified_dims_vec, ndims_unspecified + ) + permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert) + return blockedpermvcat(permblocks_specified...) end function unspecified_dims(::Tuple{Ellipsis}, unspecified_dims_vec, ndims_unspecified::Val) - return (ntuple(i -> unspecified_dims_vec[i], ndims_unspecified),) + return (ntuple(i -> unspecified_dims_vec[i], ndims_unspecified),) end function unspecified_dims(::Ellipsis, unspecified_dims_vec, ndims_unspecified::Val) - return ntuple(i -> (unspecified_dims_vec[i],), ndims_unspecified) + return ntuple(i -> (unspecified_dims_vec[i],), ndims_unspecified) end # Version of `indexin` that outputs a `blockedperm`. function blockedperm_indexin(collection, subs...) - return blockedpermvcat(map(sub -> BaseExtensions.indexin(sub, collection), subs)...) + return blockedpermvcat(map(sub -> BaseExtensions.indexin(sub, collection), subs)...) end # @@ -140,20 +140,20 @@ end # # for dispatch reason, it is convenient to have BlockLength as the first parameter -struct BlockedPermutation{BlockLength,BlockLengths,Flat} <: - AbstractBlockPermutation{BlockLength} - flat::Flat - - function BlockedPermutation{BlockLength,BlockLengths}( - flat::Tuple - ) where {BlockLength,BlockLengths} - length(flat) != sum(BlockLengths; init=0) && - throw(DimensionMismatch("Invalid total length")) - length(BlockLengths) != BlockLength && - throw(DimensionMismatch("Invalid total blocklength")) - any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length")) - return new{BlockLength,BlockLengths,typeof(flat)}(flat) - end +struct BlockedPermutation{BlockLength, BlockLengths, Flat} <: + AbstractBlockPermutation{BlockLength} + flat::Flat + + function BlockedPermutation{BlockLength, BlockLengths}( + flat::Tuple + ) where {BlockLength, BlockLengths} + length(flat) != sum(BlockLengths; init = 0) && + throw(DimensionMismatch("Invalid total length")) + length(BlockLengths) != BlockLength && + throw(DimensionMismatch("Invalid total blocklength")) + any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length")) + return new{BlockLength, BlockLengths, typeof(flat)}(flat) + end end # Base interface @@ -161,43 +161,43 @@ Base.Tuple(blockedperm::BlockedPermutation) = getfield(blockedperm, :flat) # BlockArrays interface function BlockArrays.blocklengths( - ::Type{<:BlockedPermutation{<:Any,BlockLengths}} -) where {BlockLengths} - return BlockLengths + ::Type{<:BlockedPermutation{<:Any, BlockLengths}} + ) where {BlockLengths} + return BlockLengths end function permmortar(permblocks::Tuple{Vararg{Tuple{Vararg{Int}}}}) - blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}( - flatten_tuples(permblocks) - ) - @assert isperm(blockedperm) - return blockedperm + blockedperm = BlockedPermutation{length(permblocks), length.(permblocks)}( + flatten_tuples(permblocks) + ) + @assert isperm(blockedperm) + return blockedperm end # # ============================== BlockedTrivialPermutation =============================== # -trivialperm(length::Union{Integer,Val}) = ntuple(identity, length) +trivialperm(length::Union{Integer, Val}) = ntuple(identity, length) -struct BlockedTrivialPermutation{BlockLength,BlockLengths} <: - AbstractBlockPermutation{BlockLength} end +struct BlockedTrivialPermutation{BlockLength, BlockLengths} <: + AbstractBlockPermutation{BlockLength} end Base.Tuple(blockedperm::BlockedTrivialPermutation) = trivialperm(length(blockedperm)) # BlockArrays interface function BlockArrays.blocklengths( - ::Type{<:BlockedTrivialPermutation{<:Any,BlockLengths}} -) where {BlockLengths} - return BlockLengths + ::Type{<:BlockedTrivialPermutation{<:Any, BlockLengths}} + ) where {BlockLengths} + return BlockLengths end blockedperm(tp::BlockedTrivialPermutation) = tp function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}}) - return BlockedTrivialPermutation{length(blocklengths),blocklengths}() + return BlockedTrivialPermutation{length(blocklengths), blocklengths}() end function trivialperm(blockedperm::AbstractBlockTuple) - return blockedtrivialperm(blocklengths(blockedperm)) + return blockedtrivialperm(blocklengths(blockedperm)) end Base.invperm(blockedperm::BlockedTrivialPermutation) = blockedperm diff --git a/src/blockedtuple.jl b/src/blockedtuple.jl index 30db0bd..02fac48 100644 --- a/src/blockedtuple.jl +++ b/src/blockedtuple.jl @@ -18,7 +18,7 @@ widened_constructorof(type::Type{<:AbstractBlockTuple}) = constructorof(type) # Like `BlockRange`. function blockeachindex(bt::AbstractBlockTuple) - return ntuple(i -> Block(i), blocklength(bt)) + return ntuple(i -> Block(i), blocklength(bt)) end # Base interface @@ -33,12 +33,12 @@ Base.getindex(bt::AbstractBlockTuple, i::Integer) = Tuple(bt)[i] Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r] Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)] function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1}) - r = Int.(br) - flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]] - return widened_constructorof(typeof(bt))(flat, blocklengths(bt)[r]) + r = Int.(br) + flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]] + return widened_constructorof(typeof(bt))(flat, blocklengths(bt)[r]) end function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1}) - return bt[Block(bi)][only(bi.indices)] + return bt[Block(bi)][only(bi.indices)] end # needed for nested broadcast in Julia < 1.11 Base.getindex(bt::AbstractBlockTuple, ci::CartesianIndex{1}) = bt[only(Tuple(ci))] @@ -48,27 +48,27 @@ Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i) Base.lastindex(bt::AbstractBlockTuple) = length(bt) -Base.length(bt::AbstractBlockTuple) = sum(blocklengths(bt); init=0) +Base.length(bt::AbstractBlockTuple) = sum(blocklengths(bt); init = 0) function Base.map(f, bt::AbstractBlockTuple) - BL = blocklengths(bt) - # use Val to preserve compile time knowledge of BL - return widened_constructorof(typeof(bt))(map(f, Tuple(bt)), Val(BL)) + BL = blocklengths(bt) + # use Val to preserve compile time knowledge of BL + return widened_constructorof(typeof(bt))(map(f, Tuple(bt)), Val(BL)) end function Base.show(io::IO, bt::AbstractBlockTuple) - return print(io, nameof(typeof(bt)), blocks(bt)) + return print(io, nameof(typeof(bt)), blocks(bt)) end function Base.show(io::IO, ::MIME"text/plain", bt::AbstractBlockTuple) - println(io, typeof(bt)) - return print(io, blocks(bt)) + println(io, typeof(bt)) + return print(io, blocks(bt)) end # Broadcast interface Base.broadcastable(bt::AbstractBlockTuple) = bt -struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end +struct AbstractBlockTupleBroadcastStyle{BlockLengths, BT} <: Broadcast.BroadcastStyle end function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple}) - return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}() + return AbstractBlockTupleBroadcastStyle{blocklengths(T), unspecify_type_parameters(T)}() end # default @@ -79,58 +79,58 @@ combine_types(::Type{<:AbstractBlockTuple}, ::Type{<:AbstractBlockTuple}) = Bloc # tuplemortar(((1,), (2,))) .== tuplemortar(((1, 2),)) = tuplemortar(((true,), (true,))) # tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,), (3,))) = error DimensionMismatch function Base.BroadcastStyle( - s1::AbstractBlockTupleBroadcastStyle, s2::AbstractBlockTupleBroadcastStyle -) - blocklengths1 = type_parameters(s1, 1) - blocklengths2 = type_parameters(s2, 1) - sum(blocklengths1; init=0) != sum(blocklengths2; init=0) && - throw(DimensionMismatch("blocked tuples could not be broadcast to a common size")) - new_blocklasts = static_mergesort(cumsum(blocklengths1), cumsum(blocklengths2)) - new_blocklengths = ( - first(new_blocklasts), Base.tail(new_blocklasts) .- Base.front(new_blocklasts)... - ) - BT = combine_types(type_parameters(s1, 2), type_parameters(s2, 2)) - return AbstractBlockTupleBroadcastStyle{new_blocklengths,BT}() + s1::AbstractBlockTupleBroadcastStyle, s2::AbstractBlockTupleBroadcastStyle + ) + blocklengths1 = type_parameters(s1, 1) + blocklengths2 = type_parameters(s2, 1) + sum(blocklengths1; init = 0) != sum(blocklengths2; init = 0) && + throw(DimensionMismatch("blocked tuples could not be broadcast to a common size")) + new_blocklasts = static_mergesort(cumsum(blocklengths1), cumsum(blocklengths2)) + new_blocklengths = ( + first(new_blocklasts), Base.tail(new_blocklasts) .- Base.front(new_blocklasts)..., + ) + BT = combine_types(type_parameters(s1, 2), type_parameters(s2, 2)) + return AbstractBlockTupleBroadcastStyle{new_blocklengths, BT}() end static_mergesort(::Tuple{}, ::Tuple{}) = () static_mergesort(a::Tuple, ::Tuple{}) = a static_mergesort(::Tuple{}, b::Tuple) = b function static_mergesort(a::Tuple, b::Tuple) - if first(a) == first(b) - return (first(a), static_mergesort(Base.tail(a), Base.tail(b))...) - end - if first(a) < first(b) - return (first(a), static_mergesort(Base.tail(a), b)...) - end - return (first(b), static_mergesort(a, Base.tail(b))...) + if first(a) == first(b) + return (first(a), static_mergesort(Base.tail(a), Base.tail(b))...) + end + if first(a) < first(b) + return (first(a), static_mergesort(Base.tail(a), b)...) + end + return (first(b), static_mergesort(a, Base.tail(b))...) end # tuplemortar(((1,), (2,))) .== (1, 2) = (true, true) function Base.BroadcastStyle( - s::AbstractBlockTupleBroadcastStyle, ::Base.Broadcast.Style{Tuple} -) - return s + s::AbstractBlockTupleBroadcastStyle, ::Base.Broadcast.Style{Tuple} + ) + return s end # tuplemortar(((1,), (2,))) .== 1 = (true, false) function Base.BroadcastStyle( - ::Base.Broadcast.DefaultArrayStyle{0}, s::AbstractBlockTupleBroadcastStyle -) - return s + ::Base.Broadcast.DefaultArrayStyle{0}, s::AbstractBlockTupleBroadcastStyle + ) + return s end # tuplemortar(((1,), (2,))) .== [1, 1] = BlockVector([true, false], [1, 1]) function Base.BroadcastStyle( - a::Base.Broadcast.AbstractArrayStyle, ::AbstractBlockTupleBroadcastStyle -) - return a + a::Base.Broadcast.AbstractArrayStyle, ::AbstractBlockTupleBroadcastStyle + ) + return a end function Base.copy( - bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}} -) where {BlockLengths,BT} - return widened_constructorof(BT)(bc.f.((Tuple.(bc.args))...), Val(BlockLengths)) + bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths, BT}} + ) where {BlockLengths, BT} + return widened_constructorof(BT)(bc.f.((Tuple.(bc.args))...), Val(BlockLengths)) end Base.ndims(::Type{<:AbstractBlockTuple}) = 1 # needed in nested broadcast @@ -138,11 +138,11 @@ Base.ndims(::Type{<:AbstractBlockTuple}) = 1 # needed in nested broadcast # BlockArrays interface BlockArrays.blockfirsts(::AbstractBlockTuple{0}) = () function BlockArrays.blockfirsts(bt::AbstractBlockTuple) - return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1 + return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1 end function BlockArrays.blocklasts(bt::AbstractBlockTuple) - return cumsum(blocklengths(bt)) + return cumsum(blocklengths(bt)) end BlockArrays.blocklength(::AbstractBlockTuple{BlockLength}) where {BlockLength} = BlockLength @@ -150,41 +150,41 @@ BlockArrays.blocklength(::AbstractBlockTuple{BlockLength}) where {BlockLength} = BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt)) function BlockArrays.blocks(bt::AbstractBlockTuple) - bf = blockfirsts(bt) - bl = blocklasts(bt) - return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt)) + bf = blockfirsts(bt) + bl = blocklasts(bt) + return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt)) end # ===================================== BlockedTuple ===================================== # -struct BlockedTuple{BlockLength,BlockLengths,Flat} <: AbstractBlockTuple{BlockLength} - flat::Flat +struct BlockedTuple{BlockLength, BlockLengths, Flat} <: AbstractBlockTuple{BlockLength} + flat::Flat - function BlockedTuple{BlockLength,BlockLengths}( - flat::Tuple - ) where {BlockLength,BlockLengths} - length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength")) - length(flat) != sum(BlockLengths; init=0) && - throw(DimensionMismatch("Invalid total length")) - any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length")) - return new{BlockLength,BlockLengths,typeof(flat)}(flat) - end + function BlockedTuple{BlockLength, BlockLengths}( + flat::Tuple + ) where {BlockLength, BlockLengths} + length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength")) + length(flat) != sum(BlockLengths; init = 0) && + throw(DimensionMismatch("Invalid total length")) + any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length")) + return new{BlockLength, BlockLengths, typeof(flat)}(flat) + end end # TensorAlgebra Interface function tuplemortar(tt::Tuple{Vararg{Tuple}}) - return BlockedTuple{length(tt),length.(tt)}(flatten_tuples(tt)) + return BlockedTuple{length(tt), length.(tt)}(flatten_tuples(tt)) end function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}}) - return BlockedTuple{length(BlockLengths),BlockLengths}(flat) + return BlockedTuple{length(BlockLengths), BlockLengths}(flat) end function BlockedTuple(flat::Tuple, ::Val{BlockLengths}) where {BlockLengths} - # use Val to preserve compile time knowledge of BL - return BlockedTuple{length(BlockLengths),BlockLengths}(flat) + # use Val to preserve compile time knowledge of BL + return BlockedTuple{length(BlockLengths), BlockLengths}(flat) end function BlockedTuple(bt::AbstractBlockTuple) - bl = blocklengths(bt) - return BlockedTuple{length(bl),bl}(Tuple(bt)) + bl = blocklengths(bt) + return BlockedTuple{length(bl), bl}(Tuple(bt)) end # Base interface @@ -192,7 +192,7 @@ Base.Tuple(bt::BlockedTuple) = bt.flat # BlockArrays interface function BlockArrays.blocklengths( - ::Type{<:BlockedTuple{<:Any,BlockLengths}} -) where {BlockLengths} - return BlockLengths + ::Type{<:BlockedTuple{<:Any, BlockLengths}} + ) where {BlockLengths} + return BlockLengths end diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index f1b311d..53964f6 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -1,46 +1,46 @@ using Base.PermutedDimsArrays: genperm function check_input(::typeof(contract), a1, labels1, a2, labels2) - ndims(a1) == length(labels1) || - throw(ArgumentError("Invalid permutation for left tensor")) - return ndims(a2) == length(labels2) || - throw(ArgumentError("Invalid permutation for right tensor")) + ndims(a1) == length(labels1) || + throw(ArgumentError("Invalid permutation for left tensor")) + return ndims(a2) == length(labels2) || + throw(ArgumentError("Invalid permutation for right tensor")) end function check_input(::typeof(contract), a_dest, labels_dest, a1, labels1, a2, labels2) - ndims(a_dest) == length(labels_dest) || - throw(ArgumentError("Invalid permutation for destination tensor")) - return check_input(contract, a1, labels1, a2, labels2) + ndims(a_dest) == length(labels_dest) || + throw(ArgumentError("Invalid permutation for destination tensor")) + return check_input(contract, a1, labels1, a2, labels2) end # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function output_axes( - ::typeof(contract), - biperm_dest::AbstractBlockPermutation{2}, - a1::AbstractArray, - biperm1::AbstractBlockPermutation{2}, - a2::AbstractArray, - biperm2::AbstractBlockPermutation{2}, -) - axes_codomain, axes_contracted = blocks(axes(a1)[biperm1]) - axes_contracted2, axes_domain = blocks(axes(a2)[biperm2]) - @assert length.(axes_contracted) == length.(axes_contracted2) - # default: flatten biperm_out - return genperm((axes_codomain..., axes_domain...), Tuple(biperm_dest)) + ::typeof(contract), + biperm_dest::AbstractBlockPermutation{2}, + a1::AbstractArray, + biperm1::AbstractBlockPermutation{2}, + a2::AbstractArray, + biperm2::AbstractBlockPermutation{2}, + ) + axes_codomain, axes_contracted = blocks(axes(a1)[biperm1]) + axes_contracted2, axes_domain = blocks(axes(a2)[biperm2]) + @assert length.(axes_contracted) == length.(axes_contracted2) + # default: flatten biperm_out + return genperm((axes_codomain..., axes_domain...), Tuple(biperm_dest)) end # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function allocate_output( - ::typeof(contract), - biperm_dest::AbstractBlockPermutation, - a1::AbstractArray, - biperm1::AbstractBlockPermutation, - a2::AbstractArray, - biperm2::AbstractBlockPermutation, -) - check_input(contract, a1, biperm1, a2, biperm2) - axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2) - return similar(a1, promote_type(eltype(a1), eltype(a2)), axes_dest) + ::typeof(contract), + biperm_dest::AbstractBlockPermutation, + a1::AbstractArray, + biperm1::AbstractBlockPermutation, + a2::AbstractArray, + biperm2::AbstractBlockPermutation, + ) + check_input(contract, a1, biperm1, a2, biperm2) + axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2) + return similar(a1, promote_type(eltype(a1), eltype(a2)), axes_dest) end diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index ad1ee21..790c1d9 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -3,11 +3,11 @@ using BlockArrays: blocklengths # default: if no bipartion is specified, all axes to domain function biperm(perm, blocklength1::Integer) - return biperm(perm, Val(blocklength1)) + return biperm(perm, Val(blocklength1)) end function biperm(perm, ::Val{BlockLength1}) where {BlockLength1} - length(perm) < BlockLength1 && throw(ArgumentError("Invalid codomain length")) - return blockedperm(Tuple(perm), (BlockLength1, length(perm) - BlockLength1)) + length(perm) < BlockLength1 && throw(ArgumentError("Invalid codomain length")) + return blockedperm(Tuple(perm), (BlockLength1, length(perm) - BlockLength1)) end length_domain(t::AbstractBlockTuple{2}) = last(blocklengths(t)) @@ -17,36 +17,36 @@ length_domain(t) = 0 length_codomain(t) = length(t) - length_domain(t) function blockedperms( - f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2 -) - return blockedperms(f, dimnames_dest, dimnames1, dimnames2) + f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2 + ) + return blockedperms(f, dimnames_dest, dimnames1, dimnames2) end # codomain <-- domain function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) - dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2))) - for i in unique(dimnames) - count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels")) - end - - codomain = Tuple(setdiff(dimnames1, dimnames2)) - contracted = Tuple(intersect(dimnames1, dimnames2)) - domain = Tuple(setdiff(dimnames2, dimnames1)) - - perm_codomain_dest = BaseExtensions.indexin(codomain, dimnames_dest) - perm_domain_dest = BaseExtensions.indexin(domain, dimnames_dest) - invbiperm = (perm_codomain_dest..., perm_domain_dest...) - biperm_dest = biperm(invperm(invbiperm), length_codomain(dimnames_dest)) - - perm_codomain1 = BaseExtensions.indexin(codomain, dimnames1) - perm_domain1 = BaseExtensions.indexin(contracted, dimnames1) - - perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2) - perm_domain2 = BaseExtensions.indexin(domain, dimnames2) - - permblocks1 = (perm_codomain1, perm_domain1) - biperm1 = blockedpermvcat(permblocks1...) - permblocks2 = (perm_codomain2, perm_domain2) - biperm2 = blockedpermvcat(permblocks2...) - return biperm_dest, biperm1, biperm2 + dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2))) + for i in unique(dimnames) + count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels")) + end + + codomain = Tuple(setdiff(dimnames1, dimnames2)) + contracted = Tuple(intersect(dimnames1, dimnames2)) + domain = Tuple(setdiff(dimnames2, dimnames1)) + + perm_codomain_dest = BaseExtensions.indexin(codomain, dimnames_dest) + perm_domain_dest = BaseExtensions.indexin(domain, dimnames_dest) + invbiperm = (perm_codomain_dest..., perm_domain_dest...) + biperm_dest = biperm(invperm(invbiperm), length_codomain(dimnames_dest)) + + perm_codomain1 = BaseExtensions.indexin(codomain, dimnames1) + perm_domain1 = BaseExtensions.indexin(contracted, dimnames1) + + perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2) + perm_domain2 = BaseExtensions.indexin(domain, dimnames2) + + permblocks1 = (perm_codomain1, perm_domain1) + biperm1 = blockedpermvcat(permblocks1...) + permblocks2 = (perm_codomain2, perm_domain2) + biperm2 = blockedpermvcat(permblocks2...) + return biperm_dest, biperm1, biperm2 end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index abdcd85..5d6fcf6 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -11,136 +11,136 @@ default_contract_alg() = Matricize() # Required interface if not using # matricized contraction. function contractadd!( - alg::Algorithm, - a_dest::AbstractArray, - biperm_dest::AbstractBlockPermutation, - a1::AbstractArray, - biperm1::AbstractBlockPermutation, - a2::AbstractArray, - biperm2::AbstractBlockPermutation, - α::Number, - β::Number, -) - return error("Not implemented") + alg::Algorithm, + a_dest::AbstractArray, + biperm_dest::AbstractBlockPermutation, + a1::AbstractArray, + biperm1::AbstractBlockPermutation, + a2::AbstractArray, + biperm2::AbstractBlockPermutation, + α::Number, + β::Number, + ) + return error("Not implemented") end function contract( - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; - alg=default_contract_alg(), - kwargs..., -) - return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...) + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2; + alg = default_contract_alg(), + kwargs..., + ) + return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...) end function contract( - alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs... -) - labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2; kwargs...) - return contract(alg, labels_dest, a1, labels1, a2, labels2; kwargs...), labels_dest + alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs... + ) + labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2; kwargs...) + return contract(alg, labels_dest, a1, labels1, a2, labels2; kwargs...), labels_dest end function contract( - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; - alg=default_contract_alg(), - kwargs..., -) - return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...) + labels_dest, + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2; + alg = default_contract_alg(), + kwargs..., + ) + return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...) end function contract!( - a_dest::AbstractArray, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; - kwargs..., -) - return contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...) + a_dest::AbstractArray, + labels_dest, + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2; + kwargs..., + ) + return contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...) end function contractadd!( - a_dest::AbstractArray, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2, - α::Number, - β::Number; - alg=default_contract_alg(), - kwargs..., -) - contractadd!( - Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs... - ) - return a_dest + a_dest::AbstractArray, + labels_dest, + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2, + α::Number, + β::Number; + alg = default_contract_alg(), + kwargs..., + ) + contractadd!( + Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs... + ) + return a_dest end function contract( - alg::Algorithm, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; - kwargs..., -) - check_input(contract, a1, labels1, a2, labels2) - biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) - return contract(alg, biperm_dest, a1, biperm1, a2, biperm2; kwargs...) + alg::Algorithm, + labels_dest, + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2; + kwargs..., + ) + check_input(contract, a1, labels1, a2, labels2) + biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) + return contract(alg, biperm_dest, a1, biperm1, a2, biperm2; kwargs...) end function contract!( - alg::Algorithm, - a_dest::AbstractArray, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; - kwargs..., -) - return contractadd!( - alg, a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs... - ) + alg::Algorithm, + a_dest::AbstractArray, + labels_dest, + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2; + kwargs..., + ) + return contractadd!( + alg, a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs... + ) end function contractadd!( - alg::Algorithm, - a_dest::AbstractArray, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2, - α::Number, - β::Number; - kwargs..., -) - check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2) - biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) - return contractadd!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...) + alg::Algorithm, + a_dest::AbstractArray, + labels_dest, + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2, + α::Number, + β::Number; + kwargs..., + ) + check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2) + biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) + return contractadd!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...) end function contract( - alg::Algorithm, - biperm_dest::AbstractBlockPermutation, - a1::AbstractArray, - biperm1::AbstractBlockPermutation, - a2::AbstractArray, - biperm2::AbstractBlockPermutation; - kwargs..., -) - check_input(contract, a1, biperm1, a2, biperm2) - a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2) - contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2; kwargs...) - return a_dest + alg::Algorithm, + biperm_dest::AbstractBlockPermutation, + a1::AbstractArray, + biperm1::AbstractBlockPermutation, + a2::AbstractArray, + biperm2::AbstractBlockPermutation; + kwargs..., + ) + check_input(contract, a1, biperm1, a2, biperm2) + a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2) + contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2; kwargs...) + return a_dest end diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index be4f26c..f7207ee 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -1,22 +1,22 @@ using LinearAlgebra: mul! function contractadd!( - ::Matricize, - a_dest::AbstractArray, - biperm_dest::AbstractBlockPermutation{2}, - a1::AbstractArray, - biperm1::AbstractBlockPermutation{2}, - a2::AbstractArray, - biperm2::AbstractBlockPermutation{2}, - α::Number, - β::Number, -) - invbiperm = biperm(invperm(biperm_dest), length_codomain(biperm1)) + ::Matricize, + a_dest::AbstractArray, + biperm_dest::AbstractBlockPermutation{2}, + a1::AbstractArray, + biperm1::AbstractBlockPermutation{2}, + a2::AbstractArray, + biperm2::AbstractBlockPermutation{2}, + α::Number, + β::Number, + ) + invbiperm = biperm(invperm(biperm_dest), length_codomain(biperm1)) - check_input(contract, a_dest, invbiperm, a1, biperm1, a2, biperm2) - a1_mat = matricize(a1, biperm1) - a2_mat = matricize(a2, biperm2) - a_dest_mat = a1_mat * a2_mat - unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β) - return a_dest + check_input(contract, a_dest, invbiperm, a1, biperm1, a2, biperm2) + a1_mat = matricize(a1, biperm1) + a2_mat = matricize(a2, biperm2) + a_dest_mat = a1_mat * a2_mat + unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β) + return a_dest end diff --git a/src/contract/output_labels.jl b/src/contract/output_labels.jl index c9c18bf..071869d 100644 --- a/src/contract/output_labels.jl +++ b/src/contract/output_labels.jl @@ -1,20 +1,20 @@ function output_labels( - f::typeof(contract), - alg::Algorithm, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2, -) - return output_labels(f, alg, labels1, labels2) + f::typeof(contract), + alg::Algorithm, + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2, + ) + return output_labels(f, alg, labels1, labels2) end function output_labels(f::typeof(contract), ::Algorithm, labels1, labels2) - return output_labels(f, labels1, labels2) + return output_labels(f, labels1, labels2) end function output_labels(::typeof(contract), labels1, labels2) - diff1 = Tuple(setdiff(labels1, labels2)) - diff2 = Tuple(setdiff(labels2, labels1)) - return tuplemortar((diff1, diff2)) + diff1 = Tuple(setdiff(labels1, labels2)) + diff2 = Tuple(setdiff(labels2, labels1)) + return tuplemortar((diff1, diff2)) end diff --git a/src/factorizations.jl b/src/factorizations.jl index 6312564..1f05470 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -2,27 +2,27 @@ using LinearAlgebra: LinearAlgebra using MatrixAlgebraKit: MatrixAlgebraKit for f in ( - :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize -) - @eval begin - function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return $f(A, biperm; kwargs...) + :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize, + ) + @eval begin + function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return $f(A, biperm; kwargs...) + end + function $f(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) + # tensor to matrix + A_mat = matricize(A, biperm) + + # factorization + X, Y = MatrixAlgebra.$f(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) + axes_Y = tuplemortar(((axes(Y, 1),), axes_domain)) + return unmatricize(X, axes_X), unmatricize(Y, axes_Y) + end end - function $f(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - # tensor to matrix - A_mat = matricize(A, biperm) - - # factorization - X, Y = MatrixAlgebra.$f(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blocks(axes(A)[biperm]) - axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) - axes_Y = tuplemortar(((axes(Y, 1),), axes_domain)) - return unmatricize(X, axes_X), unmatricize(Y, axes_Y) - end - end end """ @@ -160,20 +160,20 @@ See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlg `MatrixAlgebraKit.eigh_full!`, `MatrixAlgebraKit.eigh_trunc!`, and `MatrixAlgebraKit.eigh_vals!`. """ function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return eigen(A, biperm; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return eigen(A, biperm; kwargs...) end function eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - # tensor to matrix - A_mat = matricize(A, biperm) + # tensor to matrix + A_mat = matricize(A, biperm) - # factorization - D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) + # factorization + D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) - # matrix to tensor - axes_codomain, = blocks(axes(A)[biperm]) - axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) - return D, unmatricize(V, axes_V) + # matrix to tensor + axes_codomain, = blocks(axes(A)[biperm]) + axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) + return D, unmatricize(V, axes_V) end """ @@ -193,12 +193,12 @@ their labels, or directly through a `biperm`. The output is a vector of eigenval See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`. """ function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return eigvals(A, biperm; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return eigvals(A, biperm; kwargs...) end function eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - A_mat = matricize(A, biperm) - return MatrixAlgebra.eigvals!(A_mat; kwargs...) + A_mat = matricize(A, biperm) + return MatrixAlgebra.eigvals!(A_mat; kwargs...) end """ @@ -219,21 +219,21 @@ their labels, or directly through a `biperm`. See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `MatrixAlgebraKit.svd_trunc!`. """ function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return svd(A, biperm; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return svd(A, biperm; kwargs...) end function svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - # tensor to matrix - A_mat = matricize(A, biperm) + # tensor to matrix + A_mat = matricize(A, biperm) - # factorization - U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) + # factorization + U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) - # matrix to tensor - axes_codomain, axes_domain = blocks(axes(A)[biperm]) - axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) - axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) - return unmatricize(U, axes_U), S, unmatricize(Vᴴ, axes_Vᴴ) + # matrix to tensor + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) + axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) + return unmatricize(U, axes_U), S, unmatricize(Vᴴ, axes_Vᴴ) end """ @@ -247,12 +247,12 @@ their labels, or directly through a `biperm`. The output is a vector of singular See also `MatrixAlgebraKit.svd_vals!`. """ function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return svdvals(A, biperm) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return svdvals(A, biperm) end function svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) - A_mat = matricize(A, biperm) - return MatrixAlgebra.svdvals!(A_mat) + A_mat = matricize(A, biperm) + return MatrixAlgebra.svdvals!(A_mat) end """ @@ -273,15 +273,15 @@ The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`. The default is `:qrpos` if `atol == rtol == 0`, and `:svd` otherwise. """ function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return left_null(A, biperm; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return left_null(A, biperm; kwargs...) end function left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - A_mat = matricize(A, biperm) - N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) - axes_codomain = first(blocks(axes(A)[biperm])) - axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) - return unmatricize(N, axes_N) + A_mat = matricize(A, biperm) + N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) + axes_codomain = first(blocks(axes(A)[biperm])) + axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) + return unmatricize(N, axes_N) end """ @@ -302,13 +302,13 @@ The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. The default is `:lqpos` if `atol == rtol == 0`, and `:svd` otherwise. """ function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return right_null(A, biperm; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return right_null(A, biperm; kwargs...) end function right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - A_mat = matricize(A, biperm) - Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) - axes_domain = last(blocks((axes(A)[biperm]))) - axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) - return unmatricize(Nᴴ, axes_Nᴴ) + A_mat = matricize(A, biperm) + Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) + axes_domain = last(blocks((axes(A)[biperm]))) + axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) + return unmatricize(Nᴴ, axes_Nᴴ) end diff --git a/src/matricize.jl b/src/matricize.jl index d812d92..705d303 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -22,22 +22,22 @@ trivial_axis(::Tuple{Vararg{AbstractUnitRange}}) = Base.OneTo(1) trivial_axis(::Tuple{Vararg{AbstractBlockedUnitRange}}) = blockedrange([1]) function fuseaxes( - axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation -) - axesblocks = blocks(axes[blockedperm]) - return map(block -> isempty(block) ? trivial_axis(axes) : ⊗(block...), axesblocks) + axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation + ) + axesblocks = blocks(axes[blockedperm]) + return map(block -> isempty(block) ? trivial_axis(axes) : ⊗(block...), axesblocks) end # TODO remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation) - return _permutedims(a, Tuple(biperm)) + return _permutedims(a, Tuple(biperm)) end function permuteblockeddims!( - a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation -) - return _permutedims!(a, b, Tuple(biperm)) + a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation + ) + return _permutedims!(a, b, Tuple(biperm)) end # ===================================== matricize ======================================== @@ -46,84 +46,84 @@ end # maybe: copy=false kwarg function matricize(a::AbstractArray, biperm_dest::AbstractBlockPermutation{2}) - ndims(a) == length(biperm_dest) || throw(ArgumentError("Invalid bipermutation")) - return matricize(FusionStyle(a), a, biperm_dest) + ndims(a) == length(biperm_dest) || throw(ArgumentError("Invalid bipermutation")) + return matricize(FusionStyle(a), a, biperm_dest) end function matricize( - style::FusionStyle, a::AbstractArray, biperm_dest::AbstractBlockPermutation{2} -) - a_perm = permuteblockeddims(a, biperm_dest) - return matricize(style, a_perm, trivialperm(biperm_dest)) + style::FusionStyle, a::AbstractArray, biperm_dest::AbstractBlockPermutation{2} + ) + a_perm = permuteblockeddims(a, biperm_dest) + return matricize(style, a_perm, trivialperm(biperm_dest)) end function matricize( - style::FusionStyle, a::AbstractArray, biperm_dest::BlockedTrivialPermutation{2} -) - return throw(MethodError(matricize, Tuple{typeof(style),typeof(a),typeof(biperm_dest)})) + style::FusionStyle, a::AbstractArray, biperm_dest::BlockedTrivialPermutation{2} + ) + return throw(MethodError(matricize, Tuple{typeof(style), typeof(a), typeof(biperm_dest)})) end # default is reshape function matricize( - ::ReshapeFusion, a::AbstractArray, biperm_dest::BlockedTrivialPermutation{2} -) - new_axes = fuseaxes(axes(a), biperm_dest) - return reshape(a, new_axes...) + ::ReshapeFusion, a::AbstractArray, biperm_dest::BlockedTrivialPermutation{2} + ) + new_axes = fuseaxes(axes(a), biperm_dest) + return reshape(a, new_axes...) end function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple) - return matricize(a, blockedpermvcat(permblock1, permblock2; length=Val(ndims(a)))) + return matricize(a, blockedpermvcat(permblock1, permblock2; length = Val(ndims(a)))) end # ==================================== unmatricize ======================================= function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2}) - length(axes_dest) == length(invbiperm) || - throw(ArgumentError("axes do not match permutation")) - return unmatricize(FusionStyle(m), m, axes_dest, invbiperm) + length(axes_dest) == length(invbiperm) || + throw(ArgumentError("axes do not match permutation")) + return unmatricize(FusionStyle(m), m, axes_dest, invbiperm) end function unmatricize( - ::FusionStyle, m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2} -) - blocked_axes = axes_dest[invbiperm] - a12 = unmatricize(m, blocked_axes) - biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest)) + ::FusionStyle, m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2} + ) + blocked_axes = axes_dest[invbiperm] + a12 = unmatricize(m, blocked_axes) + biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest)) - return permuteblockeddims(a12, biperm_dest) + return permuteblockeddims(a12, biperm_dest) end function unmatricize( - ::ReshapeFusion, - m::AbstractMatrix, - blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}, -) - return reshape(m, Tuple(blocked_axes)...) + ::ReshapeFusion, + m::AbstractMatrix, + blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}}, + ) + return reshape(m, Tuple(blocked_axes)...) end function unmatricize(m::AbstractMatrix, blocked_axes) - return unmatricize(FusionStyle(m), m, blocked_axes) + return unmatricize(FusionStyle(m), m, blocked_axes) end function unmatricize( - m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, -) - blocked_axes = tuplemortar((codomain_axes, domain_axes)) - return unmatricize(m, blocked_axes) + m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, + ) + blocked_axes = tuplemortar((codomain_axes, domain_axes)) + return unmatricize(m, blocked_axes) end function unmatricize!(a_dest, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2}) - ndims(a_dest) == length(invbiperm) || - throw(ArgumentError("destination does not match permutation")) - blocked_axes = axes(a_dest)[invbiperm] - a_perm = unmatricize(m, blocked_axes) - biperm_dest = biperm(invperm(invbiperm), length_codomain(axes(a_dest))) - return permuteblockeddims!(a_dest, a_perm, biperm_dest) + ndims(a_dest) == length(invbiperm) || + throw(ArgumentError("destination does not match permutation")) + blocked_axes = axes(a_dest)[invbiperm] + a_perm = unmatricize(m, blocked_axes) + biperm_dest = biperm(invperm(invbiperm), length_codomain(axes(a_dest))) + return permuteblockeddims!(a_dest, a_perm, biperm_dest) end function unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β) - a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm) - a_dest .= α .* a12 .+ β .* a_dest - return a_dest + a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm) + a_dest .= α .* a12 .+ β .* a_dest + return a_dest end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index 871769e..d90ed15 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -1,46 +1,46 @@ # TensorAlgebra version of matrix functions. const MATRIX_FUNCTIONS = [ - :exp, - :cis, - :log, - :sqrt, - :cbrt, - :cos, - :sin, - :tan, - :csc, - :sec, - :cot, - :cosh, - :sinh, - :tanh, - :csch, - :sech, - :coth, - :acos, - :asin, - :atan, - :acsc, - :asec, - :acot, - :acosh, - :asinh, - :atanh, - :acsch, - :asech, - :acoth, + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, ] for f in MATRIX_FUNCTIONS - @eval begin - function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...) - return $f(a, biperm; kwargs...) + @eval begin + function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...) + return $f(a, biperm; kwargs...) + end + function $f(a::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) + a_mat = matricize(a, biperm) + fa_mat = Base.$f(a_mat; kwargs...) + return unmatricize(fa_mat, axes(a)[biperm]) + end end - function $f(a::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - a_mat = matricize(a, biperm) - fa_mat = Base.$f(a_mat; kwargs...) - return unmatricize(fa_mat, axes(a)[biperm]) - end - end end diff --git a/test/runtests.jl b/test/runtests.jl index 98b2d2b..0008050 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,60 +6,62 @@ using Suppressor: Suppressor const pat = r"(?:--group=)(\w+)" arg_id = findfirst(contains(pat), ARGS) const GROUP = uppercase( - if isnothing(arg_id) - get(ENV, "GROUP", "ALL") - else - only(match(pat, ARGS[arg_id]).captures) - end, + if isnothing(arg_id) + get(ENV, "GROUP", "ALL") + else + only(match(pat, ARGS[arg_id]).captures) + end, ) "match files of the form `test_*.jl`, but exclude `*setup*.jl`" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") end "match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" function isexamplefile(fn) - return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") + return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @time begin - # tests in groups based on folder structure - for testgroup in filter(isdir, readdir(@__DIR__)) - if GROUP == "ALL" || GROUP == uppercase(testgroup) - groupdir = joinpath(@__DIR__, testgroup) - for file in filter(istestfile, readdir(groupdir)) - filename = joinpath(groupdir, file) - @eval @safetestset $file begin - include($filename) + # tests in groups based on folder structure + for testgroup in filter(isdir, readdir(@__DIR__)) + if GROUP == "ALL" || GROUP == uppercase(testgroup) + groupdir = joinpath(@__DIR__, testgroup) + for file in filter(istestfile, readdir(groupdir)) + filename = joinpath(groupdir, file) + @eval @safetestset $file begin + include($filename) + end + end end - end end - end - # single files in top folder - for file in filter(istestfile, readdir(@__DIR__)) - (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion - @eval @safetestset $file begin - include($file) + # single files in top folder + for file in filter(istestfile, readdir(@__DIR__)) + (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion + @eval @safetestset $file begin + include($file) + end end - end - # test examples - examplepath = joinpath(@__DIR__, "..", "examples") - for (root, _, files) in walkdir(examplepath) - contains(chopprefix(root, @__DIR__), "setup") && continue - for file in filter(isexamplefile, files) - filename = joinpath(root, file) - @eval begin - @safetestset $file begin - $(Expr( - :macrocall, - GlobalRef(Suppressor, Symbol("@suppress")), - LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), - )) + # test examples + examplepath = joinpath(@__DIR__, "..", "examples") + for (root, _, files) in walkdir(examplepath) + contains(chopprefix(root, @__DIR__), "setup") && continue + for file in filter(isexamplefile, files) + filename = joinpath(root, file) + @eval begin + @safetestset $file begin + $( + Expr( + :macrocall, + GlobalRef(Suppressor, Symbol("@suppress")), + LineNumberNode(@__LINE__, @__FILE__), + :(include($filename)), + ) + ) + end + end end - end end - end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 9e2d2ee..87af441 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -5,6 +5,6 @@ using Aqua: Aqua using TensorAlgebra: TensorAlgebra @testset "Code quality (Aqua.jl)" begin - # TODO: fix and re-enable ambiguity checks - Aqua.test_all(TensorAlgebra; ambiguities=false, piracies=false) + # TODO: fix and re-enable ambiguity checks + Aqua.test_all(TensorAlgebra; ambiguities = false, piracies = false) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 3a75cec..ea7be8d 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,318 +5,318 @@ using StableRNGs: StableRNG using TensorOperations: TensorOperations using TensorAlgebra: - Algorithm, - BlockedTuple, - blockedpermvcat, - contract, - contract!, - contractadd!, - length_codomain, - length_domain, - matricize, - permuteblockeddims, - permuteblockeddims!, - tuplemortar, - unmatricize, - unmatricize! + Algorithm, + BlockedTuple, + blockedpermvcat, + contract, + contract!, + contractadd!, + length_codomain, + length_domain, + matricize, + permuteblockeddims, + permuteblockeddims!, + tuplemortar, + unmatricize, + unmatricize! default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin - @testset "misc" begin - t = (1, 2, 3) - bt = tuplemortar(((1, 2), (3,))) - @test length_codomain(t) == 3 - @test length_codomain(bt) == 2 - @test length_domain(t) == 0 - @test length_domain(bt) == 1 - end - - @testset "permuteblockeddims (eltype=$elt)" for elt in elts - a = randn(elt, 2, 3, 4, 5) - a_perm = permuteblockeddims(a, blockedpermvcat((3, 1), (2, 4))) - @test a_perm == permutedims(a, (3, 1, 2, 4)) - - a = randn(elt, 2, 3, 4, 5) - a_perm = Array{elt}(undef, (4, 2, 3, 5)) - permuteblockeddims!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) - @test a_perm == permutedims(a, (3, 1, 2, 4)) - end - @testset "matricize (eltype=$elt)" for elt in elts - a = randn(elt, 2, 3, 4, 5) - - a_fused = matricize(a, blockedpermvcat((1, 2), (3, 4))) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(a, 6, 20) - - a_fused = matricize(a, (1, 2), (3, 4)) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(a, 6, 20) - a_fused = matricize(a, (3, 1), (2, 4)) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) - a_fused = matricize(a, (3, 1, 2), (4,)) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (24, 5)) - a_fused = matricize(a, (..,), (3, 1)) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (15, 8)) - a_fused = matricize(a, (3, 1), (..,)) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) - - a_fused = matricize(a, (), (..,)) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(a, (1, 120)) - a_fused = matricize(a, (..,), ()) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(a, (120, 1)) - - @test_throws MethodError matricize(a, (1, 2), (3,), (4,)) - @test_throws MethodError matricize(a, (1, 2, 3, 4)) - @test_throws ArgumentError matricize(a, blockedpermvcat((1, 2), (3,))) - - v = ones(elt, 2) - a_fused = matricize(v, (1,), ()) - @test eltype(a_fused) === elt - @test a_fused ≈ ones(elt, 2, 1) - a_fused = matricize(v, (), (1,)) - @test eltype(a_fused) === elt - @test a_fused ≈ ones(elt, 1, 2) - - a_fused = matricize(ones(elt), (), ()) - @test eltype(a_fused) === elt - @test a_fused ≈ ones(elt, 1, 1) - end - - @testset "unmatricize (eltype=$elt)" for elt in elts - a0 = randn(elt, 2, 3, 4, 5) - axes0 = axes(a0) - m = reshape(a0, 6, 20) - - a = unmatricize(m, tuplemortar((axes0[1:2], axes0[3:4]))) - @test eltype(a) === elt - @test a ≈ a0 - - a = unmatricize(m, axes0[1:2], axes0[3:4]) - @test eltype(a) === elt - @test a ≈ a0 - - a = unmatricize(m, axes0, blockedpermvcat((1, 2), (3, 4))) - @test eltype(a) === elt - @test a ≈ a0 - - bp = blockedpermvcat((4, 2), (1, 3)) - bpinv = blockedpermvcat((3, 2), (4, 1)) - a = unmatricize(m, map(i -> axes0[i], bp), bpinv) - @test eltype(a) === elt - @test a ≈ permutedims(a0, Tuple(bp)) - - a = similar(a0) - unmatricize!(a, m, blockedpermvcat((1, 2), (3, 4))) - @test a ≈ a0 - - m1 = matricize(a0, bp) - a = unmatricize(m1, axes0, bp) - @test a ≈ a0 - - a1 = permutedims(a0, Tuple(bp)) - a = similar(a1) - unmatricize!(a, m, bpinv) - @test a ≈ a1 - - a = unmatricize(m, (), axes0) - @test eltype(a) === elt - @test a ≈ a0 - - a = unmatricize(m, axes0, ()) - @test eltype(a) === elt - @test a ≈ a0 - - m = randn(elt, 1, 1) - a = unmatricize(m, (), ()) - @test a isa Array{elt,0} - @test a[] == m[1, 1] - - @test_throws ArgumentError unmatricize(m, (), blockedpermvcat((1, 2), (3,))) - @test_throws ArgumentError unmatricize!(m, m, blockedpermvcat((1, 2), (3,))) - end - - alg_tensoroperations = Algorithm(TensorOperations.StridedBLAS()) - @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts - elt_dest = promote_type(elt1, elt2) - a1 = ones(elt1, (1, 1)) - a2 = ones(elt2, (1, 1)) - a_dest = ones(elt_dest, (1, 1)) - @test_throws ArgumentError contract(a1, (1, 2, 4), a2, (2, 3)) - @test_throws ArgumentError contract(a1, (1, 2), a2, (2, 3, 4)) - @test_throws ArgumentError contract((1, 3, 4), a1, (1, 2), a2, (2, 3)) - @test_throws ArgumentError contract((1, 3), a1, (1, 2), a2, (2, 4)) - @test_throws ArgumentError contract!(a_dest, (1, 3, 4), a1, (1, 2), a2, (2, 3)) - - dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) - labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i) - for (d1s, d2s, d_dests) in ( - ((1, 2), (1, 2), ()), - ((1, 2), (2, 1), ()), - ((1, 2), (2, 1, 3), (3,)), - ((1, 2, 3), (2, 1), (3,)), - ((1, 2), (2, 3), (1, 3)), - ((1, 2), (2, 3), (3, 1)), - ((2, 1), (2, 3), (3, 1)), - ((1, 2, 3), (2, 3, 4), (1, 4)), - ((1, 2, 3), (2, 3, 4), (4, 1)), - ((3, 2, 1), (4, 2, 3), (4, 1)), - ((1, 2, 3), (3, 4), (1, 2, 4)), - ((1, 2, 3), (3, 4), (4, 1, 2)), - ((1, 2, 3), (3, 4), (2, 4, 1)), - ((3, 1, 2), (3, 4), (2, 4, 1)), - ((3, 2, 1), (4, 3), (2, 4, 1)), - ((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)), - ((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)), - ) - a1 = randn(elt1, map(i -> dims[i], d1s)) - labels1 = map(i -> labels[i], d1s) - a2 = randn(elt2, map(i -> dims[i], d2s)) - labels2 = map(i -> labels[i], d2s) - labels_dest = map(i -> labels[i], d_dests) - - # Don't specify destination labels - a_dest, labels_dest′ = contract(a1, labels1, a2, labels2) - @test labels_dest′ isa - BlockedTuple{2,(length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))} - a_dest_tensoroperations, = contract(alg_tensoroperations, a1, labels1, a2, labels2) - @test a_dest ≈ a_dest_tensoroperations - - # Specify destination labels - a_dest = contract(labels_dest, a1, labels1, a2, labels2) - a_dest_tensoroperations = contract( - alg_tensoroperations, labels_dest, a1, labels1, a2, labels2 - ) - @test a_dest ≈ a_dest_tensoroperations - - # Specify with bituple - a_dest = contract(tuplemortar((labels_dest, ())), a1, labels1, a2, labels2) - @test a_dest ≈ a_dest_tensoroperations - a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2) - @test a_dest ≈ a_dest_tensoroperations - a_dest = contract(labels_dest′, a1, labels1, a2, labels2) - a_dest_tensoroperations = contract( - alg_tensoroperations, labels_dest′, a1, labels1, a2, labels2 - ) - @test a_dest ≈ a_dest_tensoroperations - - # Specify α and β - # TODO: Using random `α`, `β` causing - # random test failures, investigate why. - α = elt_dest(1.2) # randn(elt_dest) - β = elt_dest(2.4) # randn(elt_dest) - a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests)) - a_dest = copy(a_dest_init) - contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) - a_dest_tensoroperations = copy(a_dest_init) - contractadd!( - alg_tensoroperations, - a_dest_tensoroperations, - labels_dest, - a1, - labels1, - a2, - labels2, - α, - β, - ) - ## Here we loosened the tolerance because of some floating point roundoff issue. - ## with Float32 numbers - @test a_dest ≈ a_dest_tensoroperations rtol = 50 * default_rtol(elt_dest) + @testset "misc" begin + t = (1, 2, 3) + bt = tuplemortar(((1, 2), (3,))) + @test length_codomain(t) == 3 + @test length_codomain(bt) == 2 + @test length_domain(t) == 0 + @test length_domain(bt) == 1 + end + + @testset "permuteblockeddims (eltype=$elt)" for elt in elts + a = randn(elt, 2, 3, 4, 5) + a_perm = permuteblockeddims(a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn(elt, 2, 3, 4, 5) + a_perm = Array{elt}(undef, (4, 2, 3, 5)) + permuteblockeddims!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + end + @testset "matricize (eltype=$elt)" for elt in elts + a = randn(elt, 2, 3, 4, 5) + + a_fused = matricize(a, blockedpermvcat((1, 2), (3, 4))) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(a, 6, 20) + + a_fused = matricize(a, (1, 2), (3, 4)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(a, 6, 20) + a_fused = matricize(a, (3, 1), (2, 4)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) + a_fused = matricize(a, (3, 1, 2), (4,)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (24, 5)) + a_fused = matricize(a, (..,), (3, 1)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (15, 8)) + a_fused = matricize(a, (3, 1), (..,)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) + + a_fused = matricize(a, (), (..,)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(a, (1, 120)) + a_fused = matricize(a, (..,), ()) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(a, (120, 1)) + + @test_throws MethodError matricize(a, (1, 2), (3,), (4,)) + @test_throws MethodError matricize(a, (1, 2, 3, 4)) + @test_throws ArgumentError matricize(a, blockedpermvcat((1, 2), (3,))) + + v = ones(elt, 2) + a_fused = matricize(v, (1,), ()) + @test eltype(a_fused) === elt + @test a_fused ≈ ones(elt, 2, 1) + a_fused = matricize(v, (), (1,)) + @test eltype(a_fused) === elt + @test a_fused ≈ ones(elt, 1, 2) + + a_fused = matricize(ones(elt), (), ()) + @test eltype(a_fused) === elt + @test a_fused ≈ ones(elt, 1, 1) + end + + @testset "unmatricize (eltype=$elt)" for elt in elts + a0 = randn(elt, 2, 3, 4, 5) + axes0 = axes(a0) + m = reshape(a0, 6, 20) + + a = unmatricize(m, tuplemortar((axes0[1:2], axes0[3:4]))) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0[1:2], axes0[3:4]) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0, blockedpermvcat((1, 2), (3, 4))) + @test eltype(a) === elt + @test a ≈ a0 + + bp = blockedpermvcat((4, 2), (1, 3)) + bpinv = blockedpermvcat((3, 2), (4, 1)) + a = unmatricize(m, map(i -> axes0[i], bp), bpinv) + @test eltype(a) === elt + @test a ≈ permutedims(a0, Tuple(bp)) + + a = similar(a0) + unmatricize!(a, m, blockedpermvcat((1, 2), (3, 4))) + @test a ≈ a0 + + m1 = matricize(a0, bp) + a = unmatricize(m1, axes0, bp) + @test a ≈ a0 + + a1 = permutedims(a0, Tuple(bp)) + a = similar(a1) + unmatricize!(a, m, bpinv) + @test a ≈ a1 + + a = unmatricize(m, (), axes0) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0, ()) + @test eltype(a) === elt + @test a ≈ a0 + + m = randn(elt, 1, 1) + a = unmatricize(m, (), ()) + @test a isa Array{elt, 0} + @test a[] == m[1, 1] + + @test_throws ArgumentError unmatricize(m, (), blockedpermvcat((1, 2), (3,))) + @test_throws ArgumentError unmatricize!(m, m, blockedpermvcat((1, 2), (3,))) + end + + alg_tensoroperations = Algorithm(TensorOperations.StridedBLAS()) + @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts + elt_dest = promote_type(elt1, elt2) + a1 = ones(elt1, (1, 1)) + a2 = ones(elt2, (1, 1)) + a_dest = ones(elt_dest, (1, 1)) + @test_throws ArgumentError contract(a1, (1, 2, 4), a2, (2, 3)) + @test_throws ArgumentError contract(a1, (1, 2), a2, (2, 3, 4)) + @test_throws ArgumentError contract((1, 3, 4), a1, (1, 2), a2, (2, 3)) + @test_throws ArgumentError contract((1, 3), a1, (1, 2), a2, (2, 4)) + @test_throws ArgumentError contract!(a_dest, (1, 3, 4), a1, (1, 2), a2, (2, 3)) + + dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) + labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i) + for (d1s, d2s, d_dests) in ( + ((1, 2), (1, 2), ()), + ((1, 2), (2, 1), ()), + ((1, 2), (2, 1, 3), (3,)), + ((1, 2, 3), (2, 1), (3,)), + ((1, 2), (2, 3), (1, 3)), + ((1, 2), (2, 3), (3, 1)), + ((2, 1), (2, 3), (3, 1)), + ((1, 2, 3), (2, 3, 4), (1, 4)), + ((1, 2, 3), (2, 3, 4), (4, 1)), + ((3, 2, 1), (4, 2, 3), (4, 1)), + ((1, 2, 3), (3, 4), (1, 2, 4)), + ((1, 2, 3), (3, 4), (4, 1, 2)), + ((1, 2, 3), (3, 4), (2, 4, 1)), + ((3, 1, 2), (3, 4), (2, 4, 1)), + ((3, 2, 1), (4, 3), (2, 4, 1)), + ((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)), + ((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)), + ) + a1 = randn(elt1, map(i -> dims[i], d1s)) + labels1 = map(i -> labels[i], d1s) + a2 = randn(elt2, map(i -> dims[i], d2s)) + labels2 = map(i -> labels[i], d2s) + labels_dest = map(i -> labels[i], d_dests) + + # Don't specify destination labels + a_dest, labels_dest′ = contract(a1, labels1, a2, labels2) + @test labels_dest′ isa + BlockedTuple{2, (length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))} + a_dest_tensoroperations, = contract(alg_tensoroperations, a1, labels1, a2, labels2) + @test a_dest ≈ a_dest_tensoroperations + + # Specify destination labels + a_dest = contract(labels_dest, a1, labels1, a2, labels2) + a_dest_tensoroperations = contract( + alg_tensoroperations, labels_dest, a1, labels1, a2, labels2 + ) + @test a_dest ≈ a_dest_tensoroperations + + # Specify with bituple + a_dest = contract(tuplemortar((labels_dest, ())), a1, labels1, a2, labels2) + @test a_dest ≈ a_dest_tensoroperations + a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2) + @test a_dest ≈ a_dest_tensoroperations + a_dest = contract(labels_dest′, a1, labels1, a2, labels2) + a_dest_tensoroperations = contract( + alg_tensoroperations, labels_dest′, a1, labels1, a2, labels2 + ) + @test a_dest ≈ a_dest_tensoroperations + + # Specify α and β + # TODO: Using random `α`, `β` causing + # random test failures, investigate why. + α = elt_dest(1.2) # randn(elt_dest) + β = elt_dest(2.4) # randn(elt_dest) + a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests)) + a_dest = copy(a_dest_init) + contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) + a_dest_tensoroperations = copy(a_dest_init) + contractadd!( + alg_tensoroperations, + a_dest_tensoroperations, + labels_dest, + a1, + labels1, + a2, + labels2, + α, + β, + ) + ## Here we loosened the tolerance because of some floating point roundoff issue. + ## with Float32 numbers + @test a_dest ≈ a_dest_tensoroperations rtol = 50 * default_rtol(elt_dest) + end + end + @testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, + elt2 in elts + + elt_dest = promote_type(elt1, elt2) + + rng = StableRNG(123) + a1 = randn(rng, elt1, 2, 3) + a2 = randn(rng, elt2, 4, 5) + + a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l")) + @test labels == tuplemortar((("i", "j"), ("k", "l"))) + @test eltype(a_dest) === elt_dest + @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)) + + a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l")) + @test eltype(a_dest) === elt_dest + @test a_dest ≈ permutedims( + reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4) + ) + + a_dest = zeros(elt_dest, 2, 5, 3, 4) + contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l")) + @test a_dest ≈ permutedims( + reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3) + ) + end + @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, + elt2 in elts + + elt_dest = promote_type(elt1, elt2) + + rng = StableRNG(123) + a = randn(rng, elt1, (2, 3, 4, 5)) + s = randn(rng, elt2, ()) + t = randn(rng, elt2, ()) + + labels_a = ("i", "j", "k", "l") + + # Array-scalar contraction. + a_dest, labels_dest = contract(a, labels_a, s, ()) + @test labels_dest == tuplemortar((labels_a, ())) + @test a_dest ≈ a * s[] + + # Scalar-array contraction. + a_dest, labels_dest = contract(s, (), a, labels_a) + @test labels_dest == tuplemortar(((), labels_a)) + @test a_dest ≈ a * s[] + + # Scalar-scalar contraction. + a_dest, labels_dest = contract(s, (), t, ()) + @test labels_dest == tuplemortar(((), ())) + @test a_dest[] ≈ s[] * t[] + + # Specify output labels. + labels_dest_example = ("j", "l", "i", "k") + size_dest_example = (3, 5, 2, 4) + + # Array-scalar contraction. + a_dest = contract(labels_dest_example, a, labels_a, s, ()) + @test size(a_dest) == size_dest_example + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-array contraction. + a_dest = contract(labels_dest_example, s, (), a, labels_a) + @test size(a_dest) == size_dest_example + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-scalar contraction. + a_dest = contract((), s, (), t, ()) + @test size(a_dest) == () + @test a_dest[] ≈ s[] * t[] + + # Array-scalar contraction. + a_dest = zeros(elt_dest, size_dest_example) + contract!(a_dest, labels_dest_example, a, labels_a, s, ()) + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-array contraction. + a_dest = zeros(elt_dest, size_dest_example) + contract!(a_dest, labels_dest_example, s, (), a, labels_a) + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-scalar contraction. + a_dest = zeros(elt_dest, ()) + contract!(a_dest, (), s, (), t, ()) + @test a_dest[] ≈ s[] * t[] end - end - @testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, - elt2 in elts - - elt_dest = promote_type(elt1, elt2) - - rng = StableRNG(123) - a1 = randn(rng, elt1, 2, 3) - a2 = randn(rng, elt2, 4, 5) - - a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l")) - @test labels == tuplemortar((("i", "j"), ("k", "l"))) - @test eltype(a_dest) === elt_dest - @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)) - - a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l")) - @test eltype(a_dest) === elt_dest - @test a_dest ≈ permutedims( - reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4) - ) - - a_dest = zeros(elt_dest, 2, 5, 3, 4) - contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l")) - @test a_dest ≈ permutedims( - reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3) - ) - end - @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, - elt2 in elts - - elt_dest = promote_type(elt1, elt2) - - rng = StableRNG(123) - a = randn(rng, elt1, (2, 3, 4, 5)) - s = randn(rng, elt2, ()) - t = randn(rng, elt2, ()) - - labels_a = ("i", "j", "k", "l") - - # Array-scalar contraction. - a_dest, labels_dest = contract(a, labels_a, s, ()) - @test labels_dest == tuplemortar((labels_a, ())) - @test a_dest ≈ a * s[] - - # Scalar-array contraction. - a_dest, labels_dest = contract(s, (), a, labels_a) - @test labels_dest == tuplemortar(((), labels_a)) - @test a_dest ≈ a * s[] - - # Scalar-scalar contraction. - a_dest, labels_dest = contract(s, (), t, ()) - @test labels_dest == tuplemortar(((), ())) - @test a_dest[] ≈ s[] * t[] - - # Specify output labels. - labels_dest_example = ("j", "l", "i", "k") - size_dest_example = (3, 5, 2, 4) - - # Array-scalar contraction. - a_dest = contract(labels_dest_example, a, labels_a, s, ()) - @test size(a_dest) == size_dest_example - @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] - - # Scalar-array contraction. - a_dest = contract(labels_dest_example, s, (), a, labels_a) - @test size(a_dest) == size_dest_example - @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] - - # Scalar-scalar contraction. - a_dest = contract((), s, (), t, ()) - @test size(a_dest) == () - @test a_dest[] ≈ s[] * t[] - - # Array-scalar contraction. - a_dest = zeros(elt_dest, size_dest_example) - contract!(a_dest, labels_dest_example, a, labels_a, s, ()) - @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] - - # Scalar-array contraction. - a_dest = zeros(elt_dest, size_dest_example) - contract!(a_dest, labels_dest_example, s, (), a, labels_a) - @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] - - # Scalar-scalar contraction. - a_dest = zeros(elt_dest, ()) - contract!(a_dest, (), s, (), t, ()) - @test a_dest[] ≈ s[] * t[] - end end diff --git a/test/test_blockarrays_contract.jl b/test/test_blockarrays_contract.jl index 224a102..8dada0f 100644 --- a/test/test_blockarrays_contract.jl +++ b/test/test_blockarrays_contract.jl @@ -4,112 +4,112 @@ using TensorAlgebra: contract using Test: @test, @testset function randn_blockdiagonal(elt::Type, axes::Tuple) - a = zeros(elt, axes) - blockdiaglength = minimum(blocksize(a)) - for i in 1:blockdiaglength - b = Block(ntuple(Returns(i), ndims(a))) - a[b] = randn!(a[b]) - end - return a + a = zeros(elt, axes) + blockdiaglength = minimum(blocksize(a)) + for i in 1:blockdiaglength + b = Block(ntuple(Returns(i), ndims(a))) + a[b] = randn!(a[b]) + end + return a end const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "`contract` blocked arrays (eltype=$elt)" for elt in elts - d = blockedrange([2, 3]) - a1 = randn_blockdiagonal(elt, (d, d, d, d)) - a2 = randn_blockdiagonal(elt, (d, d, d, d)) - a3 = randn_blockdiagonal(elt, (d, d)) - a1_dense = convert(Array, a1) - a2_dense = convert(Array, a2) - a3_dense = convert(Array, a3) + d = blockedrange([2, 3]) + a1 = randn_blockdiagonal(elt, (d, d, d, d)) + a2 = randn_blockdiagonal(elt, (d, d, d, d)) + a3 = randn_blockdiagonal(elt, (d, d)) + a1_dense = convert(Array, a1) + a2_dense = convert(Array, a2) + a3_dense = convert(Array, a3) - @testset "BlockedArray" begin - # matrix matrix - a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) - a_dest_dense, dimnames_dest_dense = contract( - a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) - ) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockedArray{elt} - @test a_dest ≈ a_dest_dense + @testset "BlockedArray" begin + # matrix matrix + a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) + a_dest_dense, dimnames_dest_dense = contract( + a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) + ) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockedArray{elt} + @test a_dest ≈ a_dest_dense - # matrix vector - a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockedArray{elt} - @test a_dest ≈ a_dest_dense + # matrix vector + a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) + a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockedArray{elt} + @test a_dest ≈ a_dest_dense - # vector matrix - a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockedArray{elt} - @test a_dest ≈ a_dest_dense + # vector matrix + a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) + a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockedArray{elt} + @test a_dest ≈ a_dest_dense - # vector vector - a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockedArray{elt,0} - @test a_dest ≈ a_dest_dense + # vector vector + a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) + a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockedArray{elt, 0} + @test a_dest ≈ a_dest_dense - # outer product - a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockedArray{elt} - @test a_dest ≈ a_dest_dense - end + # outer product + a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) + a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockedArray{elt} + @test a_dest ≈ a_dest_dense + end - @testset "BlockArray" begin - a1, a3, a3 = BlockArray.((a1, a2, a3)) + @testset "BlockArray" begin + a1, a3, a3 = BlockArray.((a1, a2, a3)) - # matrix matrix - a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) - a_dest_dense, dimnames_dest_dense = contract( - a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) - ) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt} - @test a_dest ≈ a_dest_dense + # matrix matrix + a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) + a_dest_dense, dimnames_dest_dense = contract( + a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) + ) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockArray{elt} + @test a_dest ≈ a_dest_dense - # matrix vector - a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt} - @test a_dest ≈ a_dest_dense + # matrix vector + a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) + a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockArray{elt} + @test a_dest ≈ a_dest_dense - # vector matrix - a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt} - @test a_dest ≈ a_dest_dense + # vector matrix + a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) + a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockArray{elt} + @test a_dest ≈ a_dest_dense - # vector vector - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) - a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt,0} - @test a_dest ≈ a_dest_dense + # vector vector + a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) + a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockArray{elt, 0} + @test a_dest ≈ a_dest_dense - # outer product - a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt} - @test a_dest ≈ a_dest_dense - end + # outer product + a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) + a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) + @test dimnames_dest == dimnames_dest_dense + @test size(a_dest) == size(a_dest_dense) + @test a_dest isa BlockArray{elt} + @test a_dest ≈ a_dest_dense + end end diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index 016011f..9290d32 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -5,166 +5,166 @@ using EllipsisNotation: var".." using TestExtras: @constinferred using TensorAlgebra: - BlockedPermutation, - BlockedTrivialPermutation, - BlockedTuple, - blockedperm, - blockedperm_indexin, - blockpermute, - blockedtrivialperm, - blockedpermvcat, - permmortar, - trivialperm, - tuplemortar + BlockedPermutation, + BlockedTrivialPermutation, + BlockedTuple, + blockedperm, + blockedperm_indexin, + blockpermute, + blockedtrivialperm, + blockedpermvcat, + permmortar, + trivialperm, + tuplemortar @testset "BlockedPermutation" begin - p = @constinferred permmortar(((3, 4, 5), (2, 1))) - @test Tuple(p) === (3, 4, 5, 2, 1) - @test isperm(p) - @test length(p) == 5 - @test blocks(p) == ((3, 4, 5), (2, 1)) - @test blocklength(p) == 2 - @test blocklengths(p) == (3, 2) - @test blockfirsts(p) == (1, 4) - @test blocklasts(p) == (3, 5) - @test p == (@constinferred blockedpermvcat((3, 4, 5), (2, 1))) - @test p == blockedperm((3, 4, 5, 2, 1), (3, 2)) - @test p == (@constinferred blockedperm((3, 4, 5, 2, 1), Val((3, 2)))) - @test (@constinferred invperm(p)) == blockedpermvcat((5, 4, 1), (2, 3)) - @test p isa BlockedPermutation{2} - - flat = (3, 4, 5, 2, 1) - @test_throws DimensionMismatch BlockedPermutation{2,(1, 2, 2)}(flat) - @test_throws DimensionMismatch BlockedPermutation{3,(1, 2, 3)}(flat) - @test_throws DimensionMismatch BlockedPermutation{3,(-1, 3, 3)}(flat) - @test_throws AssertionError blockedpermvcat((3, 5), (2, 1)) - @test_throws AssertionError blockedpermvcat((0, 1), (2, 3)) - @test_throws AssertionError blockedpermvcat((0,)) - @test_throws AssertionError blockedpermvcat((2,)) - - # Empty block. - p = @constinferred blockedpermvcat((3, 2), (), (1,)) - @test Tuple(p) === (3, 2, 1) - @test isperm(p) - @test length(p) == 3 - @test blocks(p) == ((3, 2), (), (1,)) - @test blocklength(p) == 3 - @test blocklengths(p) == (2, 0, 1) - @test blockfirsts(p) == (1, 3, 3) - @test blocklasts(p) == (2, 2, 3) - @test invperm(p) == blockedpermvcat((3, 2), (), (1,)) - @test p isa BlockedPermutation{3} - - p = @constinferred blockedpermvcat((), ()) - @test Tuple(p) === () - @test blocklength(p) == 2 - @test blocklengths(p) == (0, 0) - @test isperm(p) - @test length(p) == 0 - @test blocks(p) == ((), ()) - @test p isa BlockedPermutation{2} - - p = @constinferred blockedpermvcat() - @test Tuple(p) === () - @test blocklength(p) == 0 - @test blocklengths(p) == () - @test isperm(p) - @test length(p) == 0 - @test blocks(p) == () - @test p isa BlockedPermutation{0} - - p = blockedpermvcat((3, 2), (), (1,)) - bt = tuplemortar(((3, 2), (), (1,))) - @test (@constinferred BlockedTuple(p)) == bt - @test (@constinferred map(identity, p)) == bt - @test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,))) - @test (@constinferred p .+ bt) == tuplemortar(((6, 4), (), (2,))) - @test (@constinferred bt .+ p) == tuplemortar(((6, 4), (), (2,))) - @test (@constinferred blockedperm(p)) == p - @test (@constinferred blockedperm(bt)) == p - - @test_throws ArgumentError blockedpermvcat((1, 3), (2, 4); length=Val(6)) - - # Split collection into `BlockedPermutation`. - p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d")) - @test p == blockedpermvcat((3, 1), (2, 4)) - - # Singleton dimensions. - p = @constinferred blockedpermvcat((2, 3), 1) - @test p == blockedpermvcat((2, 3), (1,)) - - # First dimensions are unspecified. - p = blockedpermvcat(.., (4, 3)) - @test p == blockedpermvcat((1,), (2,), (4, 3)) - # Specify length - p = @constinferred blockedpermvcat(.., (4, 3); length=Val(6)) - @test p == blockedpermvcat((1,), (2,), (5,), (6,), (4, 3)) - - # Last dimensions are unspecified. - p = blockedpermvcat((4, 3), ..) - @test p == blockedpermvcat((4, 3), (1,), (2,)) - # Specify length - p = @constinferred blockedpermvcat((4, 3), ..; length=Val(6)) - @test p == blockedpermvcat((4, 3), (1,), (2,), (5,), (6,)) - - # Middle dimensions are unspecified. - p = blockedpermvcat((4, 3), .., 1) - @test p == blockedpermvcat((4, 3), (2,), (1,)) - # Specify length - p = @constinferred blockedpermvcat((4, 3), .., 1; length=Val(6)) - @test p == blockedpermvcat((4, 3), (2,), (5,), (6,), (1,)) - - # No dimensions are unspecified. - p = blockedpermvcat((3, 2), .., 1) - @test p == blockedpermvcat((3, 2), (1,)) - - # same with (..,) instead of .. - p = blockedpermvcat((..,), (4, 3)) - @test p == blockedpermvcat((1, 2), (4, 3)) - p = @constinferred blockedpermvcat((..,), (4, 3); length=Val(6)) - @test p == blockedpermvcat((1, 2, 5, 6), (4, 3)) - - p = blockedpermvcat((4, 3), (..,)) - @test p == blockedpermvcat((4, 3), (1, 2)) - p = @constinferred blockedpermvcat((4, 3), (..,); length=Val(6)) - @test p == blockedpermvcat((4, 3), (1, 2, 5, 6)) - - p = blockedpermvcat((4, 3), (..,), 1) - @test p == blockedpermvcat((4, 3), (2,), (1,)) - p = @constinferred blockedpermvcat((4, 3), (..,), 1; length=Val(6)) - @test p == blockedpermvcat((4, 3), (2, 5, 6), (1,)) - - p = blockedpermvcat((3, 2), (..,), 1) - @test p == blockedpermvcat((3, 2), (), (1,)) - - # blockpermute - t = (1, 2, 3, 4) - pblocks = tuplemortar(((4, 3), (), (1, 2))) - p = blockedperm(pblocks) - @test (@constinferred blockpermute(t, p)) isa BlockedTuple{3,(2, 0, 2),NTuple{4,Int64}} - @test blockpermute(t, p) == pblocks - @test t[p] == pblocks - @test pblocks[p] == tuplemortar(((2, 1), (), (4, 3))) - @test p[p] == tuplemortar(((2, 1), (), (4, 3))) + p = @constinferred permmortar(((3, 4, 5), (2, 1))) + @test Tuple(p) === (3, 4, 5, 2, 1) + @test isperm(p) + @test length(p) == 5 + @test blocks(p) == ((3, 4, 5), (2, 1)) + @test blocklength(p) == 2 + @test blocklengths(p) == (3, 2) + @test blockfirsts(p) == (1, 4) + @test blocklasts(p) == (3, 5) + @test p == (@constinferred blockedpermvcat((3, 4, 5), (2, 1))) + @test p == blockedperm((3, 4, 5, 2, 1), (3, 2)) + @test p == (@constinferred blockedperm((3, 4, 5, 2, 1), Val((3, 2)))) + @test (@constinferred invperm(p)) == blockedpermvcat((5, 4, 1), (2, 3)) + @test p isa BlockedPermutation{2} + + flat = (3, 4, 5, 2, 1) + @test_throws DimensionMismatch BlockedPermutation{2, (1, 2, 2)}(flat) + @test_throws DimensionMismatch BlockedPermutation{3, (1, 2, 3)}(flat) + @test_throws DimensionMismatch BlockedPermutation{3, (-1, 3, 3)}(flat) + @test_throws AssertionError blockedpermvcat((3, 5), (2, 1)) + @test_throws AssertionError blockedpermvcat((0, 1), (2, 3)) + @test_throws AssertionError blockedpermvcat((0,)) + @test_throws AssertionError blockedpermvcat((2,)) + + # Empty block. + p = @constinferred blockedpermvcat((3, 2), (), (1,)) + @test Tuple(p) === (3, 2, 1) + @test isperm(p) + @test length(p) == 3 + @test blocks(p) == ((3, 2), (), (1,)) + @test blocklength(p) == 3 + @test blocklengths(p) == (2, 0, 1) + @test blockfirsts(p) == (1, 3, 3) + @test blocklasts(p) == (2, 2, 3) + @test invperm(p) == blockedpermvcat((3, 2), (), (1,)) + @test p isa BlockedPermutation{3} + + p = @constinferred blockedpermvcat((), ()) + @test Tuple(p) === () + @test blocklength(p) == 2 + @test blocklengths(p) == (0, 0) + @test isperm(p) + @test length(p) == 0 + @test blocks(p) == ((), ()) + @test p isa BlockedPermutation{2} + + p = @constinferred blockedpermvcat() + @test Tuple(p) === () + @test blocklength(p) == 0 + @test blocklengths(p) == () + @test isperm(p) + @test length(p) == 0 + @test blocks(p) == () + @test p isa BlockedPermutation{0} + + p = blockedpermvcat((3, 2), (), (1,)) + bt = tuplemortar(((3, 2), (), (1,))) + @test (@constinferred BlockedTuple(p)) == bt + @test (@constinferred map(identity, p)) == bt + @test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,))) + @test (@constinferred p .+ bt) == tuplemortar(((6, 4), (), (2,))) + @test (@constinferred bt .+ p) == tuplemortar(((6, 4), (), (2,))) + @test (@constinferred blockedperm(p)) == p + @test (@constinferred blockedperm(bt)) == p + + @test_throws ArgumentError blockedpermvcat((1, 3), (2, 4); length = Val(6)) + + # Split collection into `BlockedPermutation`. + p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d")) + @test p == blockedpermvcat((3, 1), (2, 4)) + + # Singleton dimensions. + p = @constinferred blockedpermvcat((2, 3), 1) + @test p == blockedpermvcat((2, 3), (1,)) + + # First dimensions are unspecified. + p = blockedpermvcat(.., (4, 3)) + @test p == blockedpermvcat((1,), (2,), (4, 3)) + # Specify length + p = @constinferred blockedpermvcat(.., (4, 3); length = Val(6)) + @test p == blockedpermvcat((1,), (2,), (5,), (6,), (4, 3)) + + # Last dimensions are unspecified. + p = blockedpermvcat((4, 3), ..) + @test p == blockedpermvcat((4, 3), (1,), (2,)) + # Specify length + p = @constinferred blockedpermvcat((4, 3), ..; length = Val(6)) + @test p == blockedpermvcat((4, 3), (1,), (2,), (5,), (6,)) + + # Middle dimensions are unspecified. + p = blockedpermvcat((4, 3), .., 1) + @test p == blockedpermvcat((4, 3), (2,), (1,)) + # Specify length + p = @constinferred blockedpermvcat((4, 3), .., 1; length = Val(6)) + @test p == blockedpermvcat((4, 3), (2,), (5,), (6,), (1,)) + + # No dimensions are unspecified. + p = blockedpermvcat((3, 2), .., 1) + @test p == blockedpermvcat((3, 2), (1,)) + + # same with (..,) instead of .. + p = blockedpermvcat((..,), (4, 3)) + @test p == blockedpermvcat((1, 2), (4, 3)) + p = @constinferred blockedpermvcat((..,), (4, 3); length = Val(6)) + @test p == blockedpermvcat((1, 2, 5, 6), (4, 3)) + + p = blockedpermvcat((4, 3), (..,)) + @test p == blockedpermvcat((4, 3), (1, 2)) + p = @constinferred blockedpermvcat((4, 3), (..,); length = Val(6)) + @test p == blockedpermvcat((4, 3), (1, 2, 5, 6)) + + p = blockedpermvcat((4, 3), (..,), 1) + @test p == blockedpermvcat((4, 3), (2,), (1,)) + p = @constinferred blockedpermvcat((4, 3), (..,), 1; length = Val(6)) + @test p == blockedpermvcat((4, 3), (2, 5, 6), (1,)) + + p = blockedpermvcat((3, 2), (..,), 1) + @test p == blockedpermvcat((3, 2), (), (1,)) + + # blockpermute + t = (1, 2, 3, 4) + pblocks = tuplemortar(((4, 3), (), (1, 2))) + p = blockedperm(pblocks) + @test (@constinferred blockpermute(t, p)) isa BlockedTuple{3, (2, 0, 2), NTuple{4, Int64}} + @test blockpermute(t, p) == pblocks + @test t[p] == pblocks + @test pblocks[p] == tuplemortar(((2, 1), (), (4, 3))) + @test p[p] == tuplemortar(((2, 1), (), (4, 3))) end @testset "BlockedTrivialPermutation" begin - tp = blockedtrivialperm((2, 0, 1)) - - @test tp isa BlockedTrivialPermutation{3} - @test Tuple(tp) == (1, 2, 3) - @test blocklength(tp) == 3 - @test blocklengths(tp) == (2, 0, 1) - @test trivialperm(blockedpermvcat((3, 2), (), (1,))) == tp - - bt = tuplemortar(((1, 2), (), (3,))) - @test (@constinferred BlockedTuple(tp)) == bt - @test (@constinferred blocks(tp)) == blocks(bt) - @test (@constinferred map(identity, tp)) == bt - @test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,))) - @test (@constinferred tp .+ Tuple(tp)) == tuplemortar(((2, 4), (), (6,))) - @test (@constinferred tp .+ BlockedTuple(tp)) == tuplemortar(((2, 4), (), (6,))) - @test (@constinferred blockedperm(tp)) == tp - @test (@constinferred trivialperm(tp)) == tp - @test (@constinferred trivialperm(bt)) == tp + tp = blockedtrivialperm((2, 0, 1)) + + @test tp isa BlockedTrivialPermutation{3} + @test Tuple(tp) == (1, 2, 3) + @test blocklength(tp) == 3 + @test blocklengths(tp) == (2, 0, 1) + @test trivialperm(blockedpermvcat((3, 2), (), (1,))) == tp + + bt = tuplemortar(((1, 2), (), (3,))) + @test (@constinferred BlockedTuple(tp)) == bt + @test (@constinferred blocks(tp)) == blocks(bt) + @test (@constinferred map(identity, tp)) == bt + @test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,))) + @test (@constinferred tp .+ Tuple(tp)) == tuplemortar(((2, 4), (), (6,))) + @test (@constinferred tp .+ BlockedTuple(tp)) == tuplemortar(((2, 4), (), (6,))) + @test (@constinferred blockedperm(tp)) == tp + @test (@constinferred trivialperm(tp)) == tp + @test (@constinferred trivialperm(bt)) == tp end diff --git a/test/test_blockedtuple.jl b/test/test_blockedtuple.jl index b19b8ab..6edb493 100644 --- a/test/test_blockedtuple.jl +++ b/test/test_blockedtuple.jl @@ -1,124 +1,124 @@ using Test: @test, @test_throws, @testset using BlockArrays: - Block, BlockVector, blocklength, blocklengths, blockedrange, blockisequal, blocks + Block, BlockVector, blocklength, blocklengths, blockedrange, blockisequal, blocks using TestExtras: @constinferred using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar @testset "BlockedTuple" begin - flat = (true, 'a', 2, "b", 3.0) - divs = (1, 2, 2) - - bt = @constinferred BlockedTuple{3,divs}(flat) - @test bt isa BlockedTuple{3} - @test (@constinferred blockeachindex(bt)) == (Block(1), Block(2), Block(3)) - - @test (@constinferred Tuple(bt)) == flat - @test (@constinferred tuplemortar(((true,), ('a', 2), ("b", 3.0)))) == bt - @test BlockedTuple(flat, divs) == bt - @test (@constinferred BlockedTuple(bt)) == bt - @test blocklength(bt) == 3 - @test blocklengths(bt) == (1, 2, 2) - @test (@constinferred blocks(bt)) == ((true,), ('a', 2), ("b", 3.0)) - - @test (@constinferred bt[1]) == true - @test (@constinferred bt[2]) == 'a' - @test (@constinferred map(identity, bt)) == bt - - # it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block - @test bt[Block(1)] == blocks(bt)[1] - @test bt[Block(2)] == blocks(bt)[2] - @test bt[Block(1):Block(2)] == tuplemortar(((true,), ('a', 2))) - @test bt[Block(2)[1:2]] == ('a', 2) - @test bt[2:4] == ('a', 2, "b") - - @test firstindex(bt) == 1 - @test lastindex(bt) == 5 - @test length(bt) == 5 - - @test iterate(bt) == (1, 2) - @test iterate(bt, 2) == ('a', 3) - @test blockisequal(only(axes(bt)), blockedrange([1, 2, 2])) - - @test_throws DimensionMismatch BlockedTuple{2,(1, 2, 2)}(flat) - @test_throws DimensionMismatch BlockedTuple{3,(1, 2, 3)}(flat) - @test_throws DimensionMismatch BlockedTuple{3,(-1, 3, 3)}(flat) - - bt = tuplemortar(((1,), (4, 2), (5, 3))) - @test bt isa BlockedTuple - @test Tuple(bt) == (1, 4, 2, 5, 3) - @test blocklengths(bt) == (1, 2, 2) - @test (@constinferred deepcopy(bt)) == bt - - @test (@constinferred map(n -> n + 1, bt)) == - BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1) - @test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) == - BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1) - @test (@constinferred bt .+ tuplemortar(((1,), (1, 1, 1), (1,)))) isa - BlockedTuple{4,(1, 2, 1, 1),NTuple{5,Int64}} - @test bt .+ tuplemortar(((1,), (1, 1, 1), (1,))) == - tuplemortar(((2,), (5, 3), (6,), (4,))) - - bt = tuplemortar(((1:2, 1:2), (1:3,))) - @test length.(bt) == tuplemortar(((2, 2), (3,))) - @test length.(length.(bt)) == tuplemortar(((1, 1), (1,))) - - bt = tuplemortar(((1,), (2,))) - @test (@constinferred bt .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} - @test (bt .== bt) == tuplemortar(((true,), (true,))) - @test (@constinferred bt .== tuplemortar(((1, 2),))) isa - BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} - @test (bt .== tuplemortar(((1, 2),))) == tuplemortar(((true,), (true,))) - @test_throws DimensionMismatch bt .== tuplemortar(((1,), (2,), (3,))) - @test (@constinferred bt .== (1, 2)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} - @test (bt .== (1, 2)) == tuplemortar(((true,), (true,))) - @test_throws DimensionMismatch bt .== (1, 2, 3) - @test (@constinferred bt .== 1) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} - @test (bt .== 1) == tuplemortar(((true,), (false,))) - @test (@constinferred bt .== (1,)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} - - @test (bt .== (1,)) == tuplemortar(((true,), (false,))) - # BlockedTuple .== AbstractVector is not type stable. Requires fix in BlockArrays - @test (bt .== [1, 1]) isa BlockVector{Bool} - @test blocks(bt .== [1, 1]) == [[true], [false]] - @test_throws DimensionMismatch bt .== [1, 2, 3] - - @test (@constinferred (1, 2) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} - @test ((1, 2) .== bt) == tuplemortar(((true,), (true,))) - @test_throws DimensionMismatch (1, 2, 3) .== bt - @test (@constinferred 1 .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} - @test (1 .== bt) == tuplemortar(((true,), (false,))) - @test (@constinferred (1,) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} - @test ((1,) .== bt) == tuplemortar(((true,), (false,))) - @test ([1, 1] .== bt) isa BlockVector{Bool} - @test blocks([1, 1] .== bt) == [[true], [false]] - - # empty blocks - bt = tuplemortar(((1,), (), (5, 3))) - @test bt isa BlockedTuple{3} - @test Tuple(bt) == (1, 5, 3) - @test blocklengths(bt) == (1, 0, 2) - @test (@constinferred blocks(bt)) == ((1,), (), (5, 3)) - @test blockisequal(only(axes(bt)), blockedrange([1, 0, 2])) - - bt = tuplemortar(((), ())) - @test bt isa BlockedTuple{2} - @test Tuple(bt) == () - @test blocklengths(bt) == (0, 0) - @test (@constinferred blocks(bt)) == ((), ()) - @test blockisequal(only(axes(bt)), blockedrange([0, 0])) - @test bt == bt .+ bt - - bt0 = tuplemortar(()) - bt1 = tuplemortar(((),)) - @test bt0 isa BlockedTuple{0} - @test Tuple(bt0) == () - @test blocklengths(bt0) == () - @test (@constinferred blocks(bt0)) == () - @test blockisequal(only(axes(bt0)), blockedrange(zeros(Int, 0))) - @test bt0 == bt0 - @test bt != bt1 - @test (@constinferred bt0 .+ bt0) == bt0 - @test (@constinferred bt0 .+ bt1) == bt1 + flat = (true, 'a', 2, "b", 3.0) + divs = (1, 2, 2) + + bt = @constinferred BlockedTuple{3, divs}(flat) + @test bt isa BlockedTuple{3} + @test (@constinferred blockeachindex(bt)) == (Block(1), Block(2), Block(3)) + + @test (@constinferred Tuple(bt)) == flat + @test (@constinferred tuplemortar(((true,), ('a', 2), ("b", 3.0)))) == bt + @test BlockedTuple(flat, divs) == bt + @test (@constinferred BlockedTuple(bt)) == bt + @test blocklength(bt) == 3 + @test blocklengths(bt) == (1, 2, 2) + @test (@constinferred blocks(bt)) == ((true,), ('a', 2), ("b", 3.0)) + + @test (@constinferred bt[1]) == true + @test (@constinferred bt[2]) == 'a' + @test (@constinferred map(identity, bt)) == bt + + # it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block + @test bt[Block(1)] == blocks(bt)[1] + @test bt[Block(2)] == blocks(bt)[2] + @test bt[Block(1):Block(2)] == tuplemortar(((true,), ('a', 2))) + @test bt[Block(2)[1:2]] == ('a', 2) + @test bt[2:4] == ('a', 2, "b") + + @test firstindex(bt) == 1 + @test lastindex(bt) == 5 + @test length(bt) == 5 + + @test iterate(bt) == (1, 2) + @test iterate(bt, 2) == ('a', 3) + @test blockisequal(only(axes(bt)), blockedrange([1, 2, 2])) + + @test_throws DimensionMismatch BlockedTuple{2, (1, 2, 2)}(flat) + @test_throws DimensionMismatch BlockedTuple{3, (1, 2, 3)}(flat) + @test_throws DimensionMismatch BlockedTuple{3, (-1, 3, 3)}(flat) + + bt = tuplemortar(((1,), (4, 2), (5, 3))) + @test bt isa BlockedTuple + @test Tuple(bt) == (1, 4, 2, 5, 3) + @test blocklengths(bt) == (1, 2, 2) + @test (@constinferred deepcopy(bt)) == bt + + @test (@constinferred map(n -> n + 1, bt)) == + BlockedTuple{3, blocklengths(bt)}(Tuple(bt) .+ 1) + @test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) == + BlockedTuple{3, blocklengths(bt)}(Tuple(bt) .+ 1) + @test (@constinferred bt .+ tuplemortar(((1,), (1, 1, 1), (1,)))) isa + BlockedTuple{4, (1, 2, 1, 1), NTuple{5, Int64}} + @test bt .+ tuplemortar(((1,), (1, 1, 1), (1,))) == + tuplemortar(((2,), (5, 3), (6,), (4,))) + + bt = tuplemortar(((1:2, 1:2), (1:3,))) + @test length.(bt) == tuplemortar(((2, 2), (3,))) + @test length.(length.(bt)) == tuplemortar(((1, 1), (1,))) + + bt = tuplemortar(((1,), (2,))) + @test (@constinferred bt .== bt) isa BlockedTuple{2, (1, 1), Tuple{Bool, Bool}} + @test (bt .== bt) == tuplemortar(((true,), (true,))) + @test (@constinferred bt .== tuplemortar(((1, 2),))) isa + BlockedTuple{2, (1, 1), Tuple{Bool, Bool}} + @test (bt .== tuplemortar(((1, 2),))) == tuplemortar(((true,), (true,))) + @test_throws DimensionMismatch bt .== tuplemortar(((1,), (2,), (3,))) + @test (@constinferred bt .== (1, 2)) isa BlockedTuple{2, (1, 1), Tuple{Bool, Bool}} + @test (bt .== (1, 2)) == tuplemortar(((true,), (true,))) + @test_throws DimensionMismatch bt .== (1, 2, 3) + @test (@constinferred bt .== 1) isa BlockedTuple{2, (1, 1), Tuple{Bool, Bool}} + @test (bt .== 1) == tuplemortar(((true,), (false,))) + @test (@constinferred bt .== (1,)) isa BlockedTuple{2, (1, 1), Tuple{Bool, Bool}} + + @test (bt .== (1,)) == tuplemortar(((true,), (false,))) + # BlockedTuple .== AbstractVector is not type stable. Requires fix in BlockArrays + @test (bt .== [1, 1]) isa BlockVector{Bool} + @test blocks(bt .== [1, 1]) == [[true], [false]] + @test_throws DimensionMismatch bt .== [1, 2, 3] + + @test (@constinferred (1, 2) .== bt) isa BlockedTuple{2, (1, 1), Tuple{Bool, Bool}} + @test ((1, 2) .== bt) == tuplemortar(((true,), (true,))) + @test_throws DimensionMismatch (1, 2, 3) .== bt + @test (@constinferred 1 .== bt) isa BlockedTuple{2, (1, 1), Tuple{Bool, Bool}} + @test (1 .== bt) == tuplemortar(((true,), (false,))) + @test (@constinferred (1,) .== bt) isa BlockedTuple{2, (1, 1), Tuple{Bool, Bool}} + @test ((1,) .== bt) == tuplemortar(((true,), (false,))) + @test ([1, 1] .== bt) isa BlockVector{Bool} + @test blocks([1, 1] .== bt) == [[true], [false]] + + # empty blocks + bt = tuplemortar(((1,), (), (5, 3))) + @test bt isa BlockedTuple{3} + @test Tuple(bt) == (1, 5, 3) + @test blocklengths(bt) == (1, 0, 2) + @test (@constinferred blocks(bt)) == ((1,), (), (5, 3)) + @test blockisequal(only(axes(bt)), blockedrange([1, 0, 2])) + + bt = tuplemortar(((), ())) + @test bt isa BlockedTuple{2} + @test Tuple(bt) == () + @test blocklengths(bt) == (0, 0) + @test (@constinferred blocks(bt)) == ((), ()) + @test blockisequal(only(axes(bt)), blockedrange([0, 0])) + @test bt == bt .+ bt + + bt0 = tuplemortar(()) + bt1 = tuplemortar(((),)) + @test bt0 isa BlockedTuple{0} + @test Tuple(bt0) == () + @test blocklengths(bt0) == () + @test (@constinferred blocks(bt0)) == () + @test blockisequal(only(axes(bt0)), blockedrange(zeros(Int, 0))) + @test bt0 == bt0 + @test bt != bt1 + @test (@constinferred bt0 .+ bt0) == bt0 + @test (@constinferred bt0 .+ bt1) == bt1 end diff --git a/test/test_exports.jl b/test/test_exports.jl index 0c7b00b..51f376e 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -3,48 +3,48 @@ using Test: @test, @testset using TensorAlgebra: TensorAlgebra @testset "Test exports" begin - exports = [ - :TensorAlgebra, - :contract, - :contract!, - :eigen, - :eigvals, - :factorize, - :left_null, - :left_orth, - :left_polar, - :lq, - :orth, - :polar, - :qr, - :right_null, - :right_orth, - :right_polar, - :svd, - :svdvals, - ] - @test issetequal(names(TensorAlgebra), exports) + exports = [ + :TensorAlgebra, + :contract, + :contract!, + :eigen, + :eigvals, + :factorize, + :left_null, + :left_orth, + :left_polar, + :lq, + :orth, + :polar, + :qr, + :right_null, + :right_orth, + :right_polar, + :svd, + :svdvals, + ] + @test issetequal(names(TensorAlgebra), exports) - exports = [ - :MatrixAlgebra, - :eigen, - :eigen!, - :eigvals, - :eigvals!, - :factorize, - :factorize!, - :lq, - :lq!, - :orth, - :orth!, - :polar, - :polar!, - :qr, - :qr!, - :svd, - :svd!, - :svdvals, - :svdvals!, - ] - @test issetequal(names(TensorAlgebra.MatrixAlgebra), exports) + exports = [ + :MatrixAlgebra, + :eigen, + :eigen!, + :eigvals, + :eigvals!, + :factorize, + :factorize!, + :lq, + :lq!, + :orth, + :orth!, + :polar, + :polar!, + :qr, + :qr!, + :svd, + :svd!, + :svdvals, + :svdvals!, + ] + @test issetequal(names(TensorAlgebra.MatrixAlgebra), exports) end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index f23f38d..43284f5 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,22 +1,22 @@ using LinearAlgebra: LinearAlgebra, norm, diag using MatrixAlgebraKit: truncrank using TensorAlgebra: - contract, - eigen, - eigvals, - factorize, - left_null, - left_orth, - left_polar, - lq, - orth, - polar, - qr, - right_null, - right_orth, - right_polar, - svd, - svdvals + contract, + eigen, + eigvals, + factorize, + left_null, + left_orth, + left_polar, + lq, + orth, + polar, + qr, + right_null, + right_orth, + right_polar, + svd, + svdvals using Test: @test, @testset using TestExtras: @constinferred @@ -25,305 +25,305 @@ elts = (Float64, ComplexF64) # QR Decomposition # ---------------- @testset "Full QR ($T)" for T in elts - A = randn(T, 5, 4, 3, 2) - labels_A = (:a, :b, :c, :d) - labels_Q = (:b, :a) - labels_R = (:d, :c) - - Acopy = deepcopy(A) - Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=true) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) - @test A ≈ A′ - @test size(Q, 1) * size(Q, 2) == size(Q, 3) # Q is unitary + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_Q = (:b, :a) + labels_R = (:d, :c) + + Acopy = deepcopy(A) + Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full = true) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) + @test A ≈ A′ + @test size(Q, 1) * size(Q, 2) == size(Q, 3) # Q is unitary end @testset "Compact QR ($T)" for T in elts - A = randn(T, 2, 3, 4, 5) # compact only makes a difference for less columns - labels_A = (:a, :b, :c, :d) - labels_Q = (:b, :a) - labels_R = (:d, :c) - - Acopy = deepcopy(A) - Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=false) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) - @test A ≈ A′ - @test size(Q, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + A = randn(T, 2, 3, 4, 5) # compact only makes a difference for less columns + labels_A = (:a, :b, :c, :d) + labels_Q = (:b, :a) + labels_R = (:d, :c) + + Acopy = deepcopy(A) + Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full = false) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) + @test A ≈ A′ + @test size(Q, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) end # LQ Decomposition # ---------------- @testset "Full LQ ($T)" for T in elts - A = randn(T, 2, 3, 4, 5) - labels_A = (:a, :b, :c, :d) - labels_Q = (:d, :c) - labels_L = (:b, :a) - - Acopy = deepcopy(A) - L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full=true) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) - @test A ≈ A′ - @test size(Q, 1) == size(Q, 2) * size(Q, 3) # Q is unitary + A = randn(T, 2, 3, 4, 5) + labels_A = (:a, :b, :c, :d) + labels_Q = (:d, :c) + labels_L = (:b, :a) + + Acopy = deepcopy(A) + L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full = true) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) + @test A ≈ A′ + @test size(Q, 1) == size(Q, 2) * size(Q, 3) # Q is unitary end @testset "Compact LQ ($T)" for T in elts - A = randn(T, 5, 4, 3, 2) # compact only makes a difference for less rows - labels_A = (:a, :b, :c, :d) - labels_Q = (:d, :c) - labels_L = (:b, :a) - - Acopy = deepcopy(A) - L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full=false) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) - @test A ≈ A′ - @test size(Q, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) # Q is unitary + A = randn(T, 5, 4, 3, 2) # compact only makes a difference for less rows + labels_A = (:a, :b, :c, :d) + labels_Q = (:d, :c) + labels_L = (:b, :a) + + Acopy = deepcopy(A) + L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full = false) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) + @test A ≈ A′ + @test size(Q, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) # Q is unitary end # Eigenvalue Decomposition # ------------------------ @testset "Eigenvalue decomposition ($T)" for T in elts - A = randn(T, 4, 3, 4, 3) # needs to be square - labels_A = (:a, :b, :c, :d) - labels_V = (:b, :a) - labels_V′ = (:d, :c) - - Acopy = deepcopy(A) - # type-unstable because of `ishermitian` difference - D, V = eigen(A, labels_A, labels_V, labels_V′; ishermitian=false) - @test A == Acopy # should not have altered initial array - @test eltype(D) == eltype(V) && eltype(D) <: Complex - - AV = contract((:a, :b, :D), A, labels_A, V, (labels_V′..., :D)) - VD = contract((:a, :b, :D), V, (labels_V..., :D′), D, (:D′, :D)) - @test AV ≈ VD - - # type-unstable because of `ishermitian` difference - Dvals = eigvals(A, labels_A, labels_V, labels_V′; ishermitian=false) - @test Dvals ≈ diag(D) - @test eltype(Dvals) <: Complex + A = randn(T, 4, 3, 4, 3) # needs to be square + labels_A = (:a, :b, :c, :d) + labels_V = (:b, :a) + labels_V′ = (:d, :c) + + Acopy = deepcopy(A) + # type-unstable because of `ishermitian` difference + D, V = eigen(A, labels_A, labels_V, labels_V′; ishermitian = false) + @test A == Acopy # should not have altered initial array + @test eltype(D) == eltype(V) && eltype(D) <: Complex + + AV = contract((:a, :b, :D), A, labels_A, V, (labels_V′..., :D)) + VD = contract((:a, :b, :D), V, (labels_V..., :D′), D, (:D′, :D)) + @test AV ≈ VD + + # type-unstable because of `ishermitian` difference + Dvals = eigvals(A, labels_A, labels_V, labels_V′; ishermitian = false) + @test Dvals ≈ diag(D) + @test eltype(Dvals) <: Complex end @testset "Hermitian eigenvalue decomposition ($T)" for T in elts - A = randn(T, 12, 12) - A = reshape(A + A', 4, 3, 4, 3) - labels_A = (:a, :b, :c, :d) - labels_V = (:b, :a) - labels_V′ = (:d, :c) - - Acopy = deepcopy(A) - # type-unstable because of `ishermitian` difference - D, V = eigen(A, labels_A, labels_V, labels_V′; ishermitian=true) - @test A == Acopy # should not have altered initial array - @test eltype(D) <: Real - @test eltype(V) == eltype(A) - - AV = contract((:a, :b, :D), A, labels_A, V, (labels_V′..., :D)) - VD = contract((:a, :b, :D), V, (labels_V..., :D′), D, (:D′, :D)) - @test AV ≈ VD - - # type-unstable because of `ishermitian` difference - Dvals = eigvals(A, labels_A, labels_V, labels_V′; ishermitian=true) - @test Dvals ≈ diag(D) - @test eltype(Dvals) <: Real + A = randn(T, 12, 12) + A = reshape(A + A', 4, 3, 4, 3) + labels_A = (:a, :b, :c, :d) + labels_V = (:b, :a) + labels_V′ = (:d, :c) + + Acopy = deepcopy(A) + # type-unstable because of `ishermitian` difference + D, V = eigen(A, labels_A, labels_V, labels_V′; ishermitian = true) + @test A == Acopy # should not have altered initial array + @test eltype(D) <: Real + @test eltype(V) == eltype(A) + + AV = contract((:a, :b, :D), A, labels_A, V, (labels_V′..., :D)) + VD = contract((:a, :b, :D), V, (labels_V..., :D′), D, (:D′, :D)) + @test AV ≈ VD + + # type-unstable because of `ishermitian` difference + Dvals = eigvals(A, labels_A, labels_V, labels_V′; ishermitian = true) + @test Dvals ≈ diag(D) + @test eltype(Dvals) <: Real end # Singular Value Decomposition # ---------------------------- @testset "Full SVD ($T)" for T in elts - A = randn(T, 5, 4, 3, 2) - labels_A = (:a, :b, :c, :d) - labels_U = (:b, :a) - labels_Vᴴ = (:d, :c) - - Acopy = deepcopy(A) - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=true) - @test A == Acopy # should not have altered initial array - US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) - @test A ≈ A′ - @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary - @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary - - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full=true) - @test A == Acopy # should not have altered initial array - US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) - @test A ≈ A′ - @test size(Vᴴ, 1) == 1 - - U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full=true) - @test A == Acopy # should not have altered initial array - US, labels_US = contract(U, (:u,), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) - @test A ≈ A′ - @test size(U, 2) == 1 + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_U = (:b, :a) + labels_Vᴴ = (:d, :c) + + Acopy = deepcopy(A) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = true) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test A ≈ A′ + @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary + @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary + + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = true) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) + @test A ≈ A′ + @test size(Vᴴ, 1) == 1 + + U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = true) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (:u,), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) + @test A ≈ A′ + @test size(U, 2) == 1 end @testset "Compact SVD ($T)" for T in elts - A = randn(T, 5, 4, 3, 2) - labels_A = (:a, :b, :c, :d) - labels_U = (:b, :a) - labels_Vᴴ = (:d, :c) - - Acopy = deepcopy(A) - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=false) - @test A == Acopy # should not have altered initial array - US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) - @test A ≈ A′ - k = min(size(S)...) - @test size(U, 3) == k == size(Vᴴ, 1) - - Svals = @constinferred svdvals(A, labels_A, labels_U, labels_Vᴴ) - @test Svals ≈ diag(S) - - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full=false) - @test A == Acopy # should not have altered initial array - US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) - @test A ≈ A′ - @test size(U, ndims(U)) == 1 == size(Vᴴ, 1) - - U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full=false) - @test A == Acopy # should not have altered initial array - US, labels_US = contract(U, (:u,), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) - @test A ≈ A′ - @test size(U, 1) == 1 == size(Vᴴ, 1) + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_U = (:b, :a) + labels_Vᴴ = (:d, :c) + + Acopy = deepcopy(A) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = false) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test A ≈ A′ + k = min(size(S)...) + @test size(U, 3) == k == size(Vᴴ, 1) + + Svals = @constinferred svdvals(A, labels_A, labels_U, labels_Vᴴ) + @test Svals ≈ diag(S) + + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = false) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) + @test A ≈ A′ + @test size(U, ndims(U)) == 1 == size(Vᴴ, 1) + + U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = false) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (:u,), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) + @test A ≈ A′ + @test size(U, 1) == 1 == size(Vᴴ, 1) end @testset "Truncated SVD ($T)" for T in elts - A = randn(T, 5, 4, 3, 2) - labels_A = (:a, :b, :c, :d) - labels_U = (:b, :a) - labels_Vᴴ = (:d, :c) - - # test truncated SVD - Acopy = deepcopy(A) - _, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ) - - trunc = truncrank(size(S_untrunc, 1) - 1) - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; trunc) - - @test A == Acopy # should not have altered initial array - US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) - @test norm(A - A′) ≈ S_untrunc[end] - @test size(S, 1) == size(S_untrunc, 1) - 1 + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_U = (:b, :a) + labels_Vᴴ = (:d, :c) + + # test truncated SVD + Acopy = deepcopy(A) + _, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ) + + trunc = truncrank(size(S_untrunc, 1) - 1) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; trunc) + + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test norm(A - A′) ≈ S_untrunc[end] + @test size(S, 1) == size(S_untrunc, 1) - 1 end @testset "Nullspace ($T)" for T in elts - A = randn(T, 5, 4, 3, 2) - labels_A = (:a, :b, :c, :d) - labels_codomain = (:b, :a) - labels_domain = (:d, :c) - - Acopy = deepcopy(A) - N = @constinferred left_null(A, labels_A, labels_codomain, labels_domain) - @test A == Acopy # should not have altered initial array - # N^ba_n' * A^ba_dc = 0 - NA = contract((:n, labels_domain...), conj(N), (labels_codomain..., :n), A, labels_A) - @test norm(NA) ≈ 0 atol = 1e-14 - NN = contract((:n, :n′), conj(N), (labels_codomain..., :n), N, (labels_codomain..., :n′)) - @test NN ≈ LinearAlgebra.I - - Nᴴ = @constinferred right_null(A, labels_A, labels_codomain, labels_domain) - @test A == Acopy # should not have altered initial array - # A^ba_dc * N^dc_n' = 0 - AN = contract((labels_codomain..., :n), A, labels_A, conj(Nᴴ), (:n, labels_domain...)) - @test norm(AN) ≈ 0 atol = 1e-14 - NN = contract((:n, :n′), Nᴴ, (:n, labels_domain...), Nᴴ, (:n′, labels_domain...)) + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_codomain = (:b, :a) + labels_domain = (:d, :c) + + Acopy = deepcopy(A) + N = @constinferred left_null(A, labels_A, labels_codomain, labels_domain) + @test A == Acopy # should not have altered initial array + # N^ba_n' * A^ba_dc = 0 + NA = contract((:n, labels_domain...), conj(N), (labels_codomain..., :n), A, labels_A) + @test norm(NA) ≈ 0 atol = 1.0e-14 + NN = contract((:n, :n′), conj(N), (labels_codomain..., :n), N, (labels_codomain..., :n′)) + @test NN ≈ LinearAlgebra.I + + Nᴴ = @constinferred right_null(A, labels_A, labels_codomain, labels_domain) + @test A == Acopy # should not have altered initial array + # A^ba_dc * N^dc_n' = 0 + AN = contract((labels_codomain..., :n), A, labels_A, conj(Nᴴ), (:n, labels_domain...)) + @test norm(AN) ≈ 0 atol = 1.0e-14 + NN = contract((:n, :n′), Nᴴ, (:n, labels_domain...), Nᴴ, (:n′, labels_domain...)) end @testset "Left polar ($T)" for T in elts - A = randn(T, 2, 2, 2, 2) - labels_A = (:a, :b, :c, :d) - labels_W = (:b, :a) - labels_P = (:d, :c) - - Acopy = deepcopy(A) - for (W, P) in ( - left_polar(A, labels_A, labels_W, labels_P), - polar(A, labels_A, labels_W, labels_P; side=:left), - polar(A, labels_A, labels_W, labels_P), - ) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) - @test A ≈ A′ - @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) - end + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_W = (:b, :a) + labels_P = (:d, :c) + + Acopy = deepcopy(A) + for (W, P) in ( + left_polar(A, labels_A, labels_W, labels_P), + polar(A, labels_A, labels_W, labels_P; side = :left), + polar(A, labels_A, labels_W, labels_P), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ A′ + @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Right polar ($T)" for T in elts - A = randn(T, 2, 2, 2, 2) - labels_A = (:a, :b, :c, :d) - labels_P = (:b, :a) - labels_W = (:d, :c) - - Acopy = deepcopy(A) - for (P, W) in ( - right_polar(A, labels_A, labels_P, labels_W), - polar(A, labels_A, labels_P, labels_W; side=:right), - ) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) - @test A ≈ A′ - @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) - end + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_P = (:b, :a) + labels_W = (:d, :c) + + Acopy = deepcopy(A) + for (P, W) in ( + right_polar(A, labels_A, labels_P, labels_W), + polar(A, labels_A, labels_P, labels_W; side = :right), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ A′ + @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Left orth ($T)" for T in elts - A = randn(T, 2, 2, 2, 2) - labels_A = (:a, :b, :c, :d) - labels_W = (:b, :a) - labels_P = (:d, :c) - - Acopy = deepcopy(A) - for (W, P) in ( - left_orth(A, labels_A, labels_W, labels_P), - orth(A, labels_A, labels_W, labels_P; side=:left), - orth(A, labels_A, labels_W, labels_P), - ) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) - @test A ≈ A′ - @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) - end + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_W = (:b, :a) + labels_P = (:d, :c) + + Acopy = deepcopy(A) + for (W, P) in ( + left_orth(A, labels_A, labels_W, labels_P), + orth(A, labels_A, labels_W, labels_P; side = :left), + orth(A, labels_A, labels_W, labels_P), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ A′ + @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Right orth ($T)" for T in elts - A = randn(T, 2, 2, 2, 2) - labels_A = (:a, :b, :c, :d) - labels_P = (:b, :a) - labels_W = (:d, :c) - - Acopy = deepcopy(A) - for (P, W) in ( - right_orth(A, labels_A, labels_P, labels_W), - orth(A, labels_A, labels_P, labels_W; side=:right), - ) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) - @test A ≈ A′ - @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) - end + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_P = (:b, :a) + labels_W = (:d, :c) + + Acopy = deepcopy(A) + for (P, W) in ( + right_orth(A, labels_A, labels_P, labels_W), + orth(A, labels_A, labels_P, labels_W; side = :right), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ A′ + @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "factorize ($T)" for T in elts - A = randn(T, 2, 2, 2, 2) - labels_A = (:a, :b, :c, :d) - labels_X = (:b, :a) - labels_Y = (:d, :c) - - Acopy = deepcopy(A) - for orth in (:left, :right) - X, Y = factorize(A, labels_A, labels_X, labels_Y; orth) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, X, (labels_X..., :x), Y, (:x, labels_Y...)) - @test A ≈ A′ - @test size(X, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) - end + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_X = (:b, :a) + labels_Y = (:d, :c) + + Acopy = deepcopy(A) + for orth in (:left, :right) + X, Y = factorize(A, labels_A, labels_X, labels_Y; orth) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, X, (labels_X..., :x), Y, (:x, labels_Y...)) + @test A ≈ A′ + @test size(X, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 7feefa3..1123120 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -7,285 +7,285 @@ using Test: @test, @testset elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "TensorAlgebra.MatrixAlgebra (elt=$elt)" for elt in elts - @testset "Factorizations" begin - rng = StableRNG(123) - A = randn(rng, elt, 3, 2) - for positive in (false, true) - for (Q, R) in - (MatrixAlgebra.qr(A; positive), MatrixAlgebra.qr(A; full=false, positive)) - @test A ≈ Q * R - @test size(Q) == size(A) - @test size(R) == (size(A, 2), size(A, 2)) - @test Q' * Q ≈ I - @test Q * Q' ≉ I - if positive - @test all(≥(0), real(diag(R))) - @test all(≈(0), imag(diag(R))) + @testset "Factorizations" begin + rng = StableRNG(123) + A = randn(rng, elt, 3, 2) + for positive in (false, true) + for (Q, R) in + (MatrixAlgebra.qr(A; positive), MatrixAlgebra.qr(A; full = false, positive)) + @test A ≈ Q * R + @test size(Q) == size(A) + @test size(R) == (size(A, 2), size(A, 2)) + @test Q' * Q ≈ I + @test Q * Q' ≉ I + if positive + @test all(≥(0), real(diag(R))) + @test all(≈(0), imag(diag(R))) + end + end end - end - end - A = randn(elt, 3, 2) - for positive in (false, true) - Q, R = MatrixAlgebra.qr(A; full=true, positive) - @test A ≈ Q * R - @test size(Q) == (size(A, 1), size(A, 1)) - @test size(R) == size(A) - @test Q' * Q ≈ I - @test Q * Q' ≈ I - if positive - @test all(≥(0), real(diag(R))) - @test all(≈(0), imag(diag(R))) - end - end + A = randn(elt, 3, 2) + for positive in (false, true) + Q, R = MatrixAlgebra.qr(A; full = true, positive) + @test A ≈ Q * R + @test size(Q) == (size(A, 1), size(A, 1)) + @test size(R) == size(A) + @test Q' * Q ≈ I + @test Q * Q' ≈ I + if positive + @test all(≥(0), real(diag(R))) + @test all(≈(0), imag(diag(R))) + end + end - A = randn(elt, 2, 3) - for positive in (false, true) - for (L, Q) in - (MatrixAlgebra.lq(A; positive), MatrixAlgebra.lq(A; full=false, positive)) - @test A ≈ L * Q - @test size(L) == (size(A, 1), size(A, 1)) - @test size(Q) == size(A) - @test Q * Q' ≈ I - @test Q' * Q ≉ I - if positive - @test all(≥(0), real(diag(L))) - @test all(≈(0), imag(diag(L))) + A = randn(elt, 2, 3) + for positive in (false, true) + for (L, Q) in + (MatrixAlgebra.lq(A; positive), MatrixAlgebra.lq(A; full = false, positive)) + @test A ≈ L * Q + @test size(L) == (size(A, 1), size(A, 1)) + @test size(Q) == size(A) + @test Q * Q' ≈ I + @test Q' * Q ≉ I + if positive + @test all(≥(0), real(diag(L))) + @test all(≈(0), imag(diag(L))) + end + end end - end - end - A = randn(elt, 3, 2) - for positive in (false, true) - L, Q = MatrixAlgebra.lq(A; full=true, positive) - @test A ≈ L * Q - @test size(L) == size(A) - @test size(Q) == (size(A, 2), size(A, 2)) - @test Q * Q' ≈ I - @test Q' * Q ≈ I - if positive - @test all(≥(0), real(diag(L))) - @test all(≈(0), imag(diag(L))) - end - end + A = randn(elt, 3, 2) + for positive in (false, true) + L, Q = MatrixAlgebra.lq(A; full = true, positive) + @test A ≈ L * Q + @test size(L) == size(A) + @test size(Q) == (size(A, 2), size(A, 2)) + @test Q * Q' ≈ I + @test Q' * Q ≈ I + if positive + @test all(≥(0), real(diag(L))) + @test all(≈(0), imag(diag(L))) + end + end - A = randn(elt, 3, 2) - for (W, C) in (MatrixAlgebra.orth(A), MatrixAlgebra.orth(A; side=:left)) - @test A ≈ W * C - @test size(W) == size(A) - @test size(C) == (size(A, 2), size(A, 2)) - @test W' * W ≈ I - @test W * W' ≉ I - end + A = randn(elt, 3, 2) + for (W, C) in (MatrixAlgebra.orth(A), MatrixAlgebra.orth(A; side = :left)) + @test A ≈ W * C + @test size(W) == size(A) + @test size(C) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + end - A = randn(elt, 2, 3) - C, W = MatrixAlgebra.orth(A; side=:right) - @test A ≈ C * W - @test size(C) == (size(A, 1), size(A, 1)) - @test size(W) == size(A) - @test W * W' ≈ I - @test W' * W ≉ I + A = randn(elt, 2, 3) + C, W = MatrixAlgebra.orth(A; side = :right) + @test A ≈ C * W + @test size(C) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I - A = randn(elt, 3, 2) - for (W, P) in (MatrixAlgebra.polar(A), MatrixAlgebra.polar(A; side=:left)) - @test A ≈ W * P - @test size(W) == size(A) - @test size(P) == (size(A, 2), size(A, 2)) - @test W' * W ≈ I - @test W * W' ≉ I - @test isposdef(P) - end + A = randn(elt, 3, 2) + for (W, P) in (MatrixAlgebra.polar(A), MatrixAlgebra.polar(A; side = :left)) + @test A ≈ W * P + @test size(W) == size(A) + @test size(P) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + @test isposdef(P) + end - A = randn(elt, 2, 3) - P, W = MatrixAlgebra.polar(A; side=:right) - @test A ≈ P * W - @test size(P) == (size(A, 1), size(A, 1)) - @test size(W) == size(A) - @test W * W' ≈ I - @test W' * W ≉ I - @test isposdef(P) + A = randn(elt, 2, 3) + P, W = MatrixAlgebra.polar(A; side = :right) + @test A ≈ P * W + @test size(P) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I + @test isposdef(P) - A = randn(elt, 3, 2) - for (W, C) in (MatrixAlgebra.factorize(A), MatrixAlgebra.factorize(A; orth=:left)) - @test A ≈ W * C - @test size(W) == size(A) - @test size(C) == (size(A, 2), size(A, 2)) - @test W' * W ≈ I - @test W * W' ≉ I - end + A = randn(elt, 3, 2) + for (W, C) in (MatrixAlgebra.factorize(A), MatrixAlgebra.factorize(A; orth = :left)) + @test A ≈ W * C + @test size(W) == size(A) + @test size(C) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + end - A = randn(elt, 2, 3) - C, W = MatrixAlgebra.factorize(A; orth=:right) - @test A ≈ C * W - @test size(C) == (size(A, 1), size(A, 1)) - @test size(W) == size(A) - @test W * W' ≈ I - @test W' * W ≉ I + A = randn(elt, 2, 3) + C, W = MatrixAlgebra.factorize(A; orth = :right) + @test A ≈ C * W + @test size(C) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I - A = randn(elt, 3, 3) - D, V = MatrixAlgebra.eigen(A) - @test A * V ≈ V * D - @test MatrixAlgebra.eigvals(A) ≈ diag(D) + A = randn(elt, 3, 3) + D, V = MatrixAlgebra.eigen(A) + @test A * V ≈ V * D + @test MatrixAlgebra.eigvals(A) ≈ diag(D) - A = randn(elt, 3, 2) - for (U, S, V) in (MatrixAlgebra.svd(A), MatrixAlgebra.svd(A; full=false)) - @test A ≈ U * S * V - @test size(U) == size(A) - @test size(S) == (size(A, 2), size(A, 2)) - @test size(V) == (size(A, 2), size(A, 2)) - @test U' * U ≈ I - @test U * U' ≉ I - @test V * V' ≈ I - @test V' * V ≈ I - @test MatrixAlgebra.svdvals(A) ≈ diag(S) - end + A = randn(elt, 3, 2) + for (U, S, V) in (MatrixAlgebra.svd(A), MatrixAlgebra.svd(A; full = false)) + @test A ≈ U * S * V + @test size(U) == size(A) + @test size(S) == (size(A, 2), size(A, 2)) + @test size(V) == (size(A, 2), size(A, 2)) + @test U' * U ≈ I + @test U * U' ≉ I + @test V * V' ≈ I + @test V' * V ≈ I + @test MatrixAlgebra.svdvals(A) ≈ diag(S) + end - A = randn(elt, 3, 2) - U, S, V = MatrixAlgebra.svd(A; full=true) - @test A ≈ U * S * V - @test size(U) == (size(A, 1), size(A, 1)) - @test size(S) == size(A) - @test size(V) == (size(A, 2), size(A, 2)) - @test U' * U ≈ I - @test U * U' ≈ I - @test V * V' ≈ I - @test V' * V ≈ I - @test MatrixAlgebra.svdvals(A) ≈ diag(S) - end - @testset "Truncate degenerate" begin - s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01]) - n = length(diag(s)) - rng = StableRNG(123) - u, _ = qr_compact(randn(rng, elt, n, n); positive=true) - v, _ = qr_compact(randn(rng, elt, n, n); positive=true) - a = u * s * v + A = randn(elt, 3, 2) + U, S, V = MatrixAlgebra.svd(A; full = true) + @test A ≈ U * S * V + @test size(U) == (size(A, 1), size(A, 1)) + @test size(S) == size(A) + @test size(V) == (size(A, 2), size(A, 2)) + @test U' * U ≈ I + @test U * U' ≈ I + @test V * V' ≈ I + @test V' * V ≈ I + @test MatrixAlgebra.svdvals(A) ≈ diag(S) + end + @testset "Truncate degenerate" begin + s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01]) + n = length(diag(s)) + rng = StableRNG(123) + u, _ = qr_compact(randn(rng, elt, n, n); positive = true) + v, _ = qr_compact(randn(rng, elt, n, n); positive = true) + a = u * s * v - ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(n); atol=0.1)) - @test size(ũ) == (n, n) - @test size(s̃) == (n, n) - @test size(ṽ) == (n, n) - @test ũ * s̃ * ṽ ≈ a + ũ, s̃, ṽ = svd_trunc(a; trunc = truncdegen(truncrank(n); atol = 0.1)) + @test size(ũ) == (n, n) + @test size(s̃) == (n, n) + @test size(ṽ) == (n, n) + @test ũ * s̃ * ṽ ≈ a - for kwargs in ( - (; atol=eps(real(elt))), - (; rtol=(√eps(real(elt)))), - (; atol=eps(real(elt)), rtol=(√eps(real(elt)))), - ) - ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(5); kwargs...)) - @test size(ũ) == (n, 4) - @test size(s̃) == (4, 4) - @test size(ṽ) == (4, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01]) - end + for kwargs in ( + (; atol = eps(real(elt))), + (; rtol = (√eps(real(elt)))), + (; atol = eps(real(elt)), rtol = (√eps(real(elt)))), + ) + ũ, s̃, ṽ = svd_trunc(a; trunc = truncdegen(truncrank(5); kwargs...)) + @test size(ũ) == (n, 4) + @test size(s̃) == (4, 4) + @test size(ṽ) == (4, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01]) + end - for kwargs in ( - (; atol=eps(real(elt))), - (; rtol=eps(real(elt))), - (; atol=eps(real(elt)), rtol=eps(real(elt))), - ) - ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(4); kwargs...)) - @test size(ũ) == (n, 4) - @test size(s̃) == (4, 4) - @test size(ṽ) == (4, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01]) - end + for kwargs in ( + (; atol = eps(real(elt))), + (; rtol = eps(real(elt))), + (; atol = eps(real(elt)), rtol = eps(real(elt))), + ) + ũ, s̃, ṽ = svd_trunc(a; trunc = truncdegen(truncrank(4); kwargs...)) + @test size(ũ) == (n, 4) + @test size(s̃) == (4, 4) + @test size(ṽ) == (4, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01]) + end - trunc = truncdegen(truncrank(3); atol=0.01 - √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 3) - @test size(s̃) == (3, 3) - @test size(ṽ) == (3, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); atol = 0.01 - √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 3) + @test size(s̃) == (3, 3) + @test size(ṽ) == (3, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); rtol=0.01/0.3 - √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 3) - @test size(s̃) == (3, 3) - @test size(ṽ) == (3, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); rtol = 0.01 / 0.3 - √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 3) + @test size(s̃) == (3, 3) + @test size(ṽ) == (3, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); atol=0.01 + √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); atol = 0.01 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); rtol=0.01/0.29 + √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); rtol = 0.01 / 0.29 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); atol=0.02 - √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); atol = 0.02 - √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); rtol=0.02/0.29 - √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); rtol = 0.02 / 0.29 - √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); atol=0.03 + √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 1) - @test size(s̃) == (1, 1) - @test size(ṽ) == (1, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); atol = 0.03 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); rtol=0.03/0.29 + √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 1) - @test size(s̃) == (1, 1) - @test size(ṽ) == (1, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); rtol = 0.03 / 0.29 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); atol=0.01, rtol=0.03/0.29 + √eps(real(elt))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 1) - @test size(s̃) == (1, 1) - @test size(ṽ) == (1, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); atol = 0.01, rtol = 0.03 / 0.29 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); atol=0.03 + √eps(real(elt)), rtol=0.01/0.29) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 1) - @test size(s̃) == (1, 1) - @test size(ṽ) == (1, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); atol = 0.03 + √eps(real(elt)), rtol = 0.01 / 0.29) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); atol=(2 - 0.29) - √(eps(real(elt)))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 1) - @test size(s̃) == (1, 1) - @test size(ṽ) == (1, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); atol = (2 - 0.29) - √(eps(real(elt)))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 - √(eps(real(elt)))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 1) - @test size(s̃) == (1, 1) - @test size(ṽ) == (1, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + trunc = truncdegen(truncrank(3); rtol = (2 - 0.29) / 0.29 - √(eps(real(elt)))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) - trunc = truncdegen(truncrank(3); atol=(2 - 0.29) + √(eps(real(elt)))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 0) - @test size(s̃) == (0, 0) - @test size(ṽ) == (0, n) - @test norm(ũ * s̃ * ṽ) ≈ 0 + trunc = truncdegen(truncrank(3); atol = (2 - 0.29) + √(eps(real(elt)))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 0) + @test size(s̃) == (0, 0) + @test size(ṽ) == (0, n) + @test norm(ũ * s̃ * ṽ) ≈ 0 - trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 + √(eps(real(elt)))) - ũ, s̃, ṽ = svd_trunc(a; trunc) - @test size(ũ) == (n, 0) - @test size(s̃) == (0, 0) - @test size(ṽ) == (0, n) - @test norm(ũ * s̃ * ṽ) ≈ 0 - end + trunc = truncdegen(truncrank(3); rtol = (2 - 0.29) / 0.29 + √(eps(real(elt)))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 0) + @test size(s̃) == (0, 0) + @test size(ṽ) == (0, n) + @test norm(ũ * s̃ * ṽ) ≈ 0 + end end diff --git a/test/test_matrixfunctions.jl b/test/test_matrixfunctions.jl index 3e35e3c..7ca7daa 100644 --- a/test/test_matrixfunctions.jl +++ b/test/test_matrixfunctions.jl @@ -3,19 +3,19 @@ using TensorAlgebra: TensorAlgebra, biperm using Test: @test, @testset @testset "Matrix functions (eltype=$elt)" for elt in (Float32, ComplexF64) - for f in TensorAlgebra.MATRIX_FUNCTIONS - f == :cbrt && elt <: Complex && continue - f == :cbrt && VERSION < v"1.11-" && continue - @eval begin - rng = StableRNG(123) - a = randn(rng, $elt, (2, 2, 2, 2)) - for fa in ( - TensorAlgebra.$f(a, (:a, :b, :c, :d), (:c, :b), (:d, :a)), - TensorAlgebra.$f(a, biperm((3, 2, 4, 1), Val(2))), - ) - fa′ = reshape($f(reshape(permutedims(a, (3, 2, 4, 1)), (4, 4))), (2, 2, 2, 2)) - @test fa ≈ fa′ - end + for f in TensorAlgebra.MATRIX_FUNCTIONS + f == :cbrt && elt <: Complex && continue + f == :cbrt && VERSION < v"1.11-" && continue + @eval begin + rng = StableRNG(123) + a = randn(rng, $elt, (2, 2, 2, 2)) + for fa in ( + TensorAlgebra.$f(a, (:a, :b, :c, :d), (:c, :b), (:d, :a)), + TensorAlgebra.$f(a, biperm((3, 2, 4, 1), Val(2))), + ) + fa′ = reshape($f(reshape(permutedims(a, (3, 2, 4, 1)), (4, 4))), (2, 2, 2, 2)) + @test fa ≈ fa′ + end + end end - end end diff --git a/test/test_tensoroperations.jl b/test/test_tensoroperations.jl index 2ce04fd..e6c6592 100644 --- a/test/test_tensoroperations.jl +++ b/test/test_tensoroperations.jl @@ -3,121 +3,121 @@ using TensorOperations: @tensor, ncon, tensorcontract using TensorAlgebra: Matricize @testset "tensorcontract" begin - A = randn(Float64, (3, 20, 5, 3, 4)) - B = randn(Float64, (5, 6, 20, 3)) - C1 = @inferred tensorcontract( - A, ((1, 4, 5), (2, 3)), false, B, ((3, 1), (2, 4)), false, ((1, 5, 3, 2, 4), ()), 1.0 - ) - C2 = @inferred tensorcontract( - A, - ((1, 4, 5), (2, 3)), - false, - B, - ((3, 1), (2, 4)), - false, - ((1, 5, 3, 2, 4), ()), - 1.0, - Matricize(), - ) - @test C1 ≈ C2 + A = randn(Float64, (3, 20, 5, 3, 4)) + B = randn(Float64, (5, 6, 20, 3)) + C1 = @inferred tensorcontract( + A, ((1, 4, 5), (2, 3)), false, B, ((3, 1), (2, 4)), false, ((1, 5, 3, 2, 4), ()), 1.0 + ) + C2 = @inferred tensorcontract( + A, + ((1, 4, 5), (2, 3)), + false, + B, + ((3, 1), (2, 4)), + false, + ((1, 5, 3, 2, 4), ()), + 1.0, + Matricize(), + ) + @test C1 ≈ C2 end elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "tensor network examples ($T)" for T in elts - D1, D2, D3 = 30, 40, 20 - d1, d2 = 2, 3 - A1 = rand(T, D1, d1, D2) .- 1//2 - A2 = rand(T, D2, d2, D3) .- 1//2 - rhoL = rand(T, D1, D1) .- 1//2 - rhoR = rand(T, D3, D3) .- 1//2 - H = rand(T, d1, d2, d1, d2) .- 1//2 + D1, D2, D3 = 30, 40, 20 + d1, d2 = 2, 3 + A1 = rand(T, D1, d1, D2) .- 1 // 2 + A2 = rand(T, D2, d2, D3) .- 1 // 2 + rhoL = rand(T, D1, D1) .- 1 // 2 + rhoR = rand(T, D3, D3) .- 1 // 2 + H = rand(T, d1, d2, d1, d2) .- 1 // 2 - @tensor HrA12[a, s1, s2, c] := - rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] - @tensor backend = Matricize() HrA12′[a, s1, s2, c] := - rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] + @tensor HrA12[a, s1, s2, c] := + rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] + @tensor backend = Matricize() HrA12′[a, s1, s2, c] := + rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] - @test HrA12 ≈ HrA12′ - @test HrA12 ≈ ncon( - [rhoL, H, A2, rhoR, A1], - [[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]]; - backend=Matricize(), - ) - E = @tensor rhoL[a', a] * - A1[a, s, b] * - A2[b, s', c] * - rhoR[c, c'] * - H[t, t', s, s'] * - conj(A1[a', t, b']) * - conj(A2[b', t', c']) - @test E ≈ @tensor backend = Matricize() rhoL[a', a] * - A1[a, s, b] * - A2[b, s', c] * - rhoR[c, c'] * - H[t, t', s, s'] * - conj(A1[a', t, b']) * - conj(A2[b', t', c']) + @test HrA12 ≈ HrA12′ + @test HrA12 ≈ ncon( + [rhoL, H, A2, rhoR, A1], + [[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]]; + backend = Matricize(), + ) + E = @tensor rhoL[a', a] * + A1[a, s, b] * + A2[b, s', c] * + rhoR[c, c'] * + H[t, t', s, s'] * + conj(A1[a', t, b']) * + conj(A2[b', t', c']) + @test E ≈ @tensor backend = Matricize() rhoL[a', a] * + A1[a, s, b] * + A2[b, s', c] * + rhoR[c, c'] * + H[t, t', s, s'] * + conj(A1[a', t, b']) * + conj(A2[b', t', c']) end function generate_random_network( - num_contracted_inds, num_open_inds, max_dim, max_ind_per_tensor -) - contracted_indices = repeat(collect(1:num_contracted_inds), 2) - open_indices = collect(1:num_open_inds) - dimensions = [ - repeat(rand(1:max_dim, num_contracted_inds), 2) - rand(1:max_dim, num_open_inds) - ] + num_contracted_inds, num_open_inds, max_dim, max_ind_per_tensor + ) + contracted_indices = repeat(collect(1:num_contracted_inds), 2) + open_indices = collect(1:num_open_inds) + dimensions = [ + repeat(rand(1:max_dim, num_contracted_inds), 2) + rand(1:max_dim, num_open_inds) + ] - sizes = Vector{Int64}[] - indices = Vector{Int64}[] + sizes = Vector{Int64}[] + indices = Vector{Int64}[] - while !isempty(contracted_indices) || !isempty(open_indices) - num_inds = rand( - 1:min(max_ind_per_tensor, length(contracted_indices) + length(open_indices)) - ) + while !isempty(contracted_indices) || !isempty(open_indices) + num_inds = rand( + 1:min(max_ind_per_tensor, length(contracted_indices) + length(open_indices)) + ) - cur_inds = Int64[] - cur_dims = Int64[] + cur_inds = Int64[] + cur_dims = Int64[] - for _ in 1:num_inds - curind_index = rand(1:(length(contracted_indices) + length(open_indices))) + for _ in 1:num_inds + curind_index = rand(1:(length(contracted_indices) + length(open_indices))) - if curind_index <= length(contracted_indices) - push!(cur_inds, contracted_indices[curind_index]) - push!(cur_dims, dimensions[curind_index]) - deleteat!(contracted_indices, curind_index) - deleteat!(dimensions, curind_index) - else - tind = curind_index - length(contracted_indices) - push!(cur_inds, -open_indices[tind]) - push!(cur_dims, dimensions[curind_index]) - deleteat!(open_indices, tind) - deleteat!(dimensions, curind_index) - end - end + if curind_index <= length(contracted_indices) + push!(cur_inds, contracted_indices[curind_index]) + push!(cur_dims, dimensions[curind_index]) + deleteat!(contracted_indices, curind_index) + deleteat!(dimensions, curind_index) + else + tind = curind_index - length(contracted_indices) + push!(cur_inds, -open_indices[tind]) + push!(cur_dims, dimensions[curind_index]) + deleteat!(open_indices, tind) + deleteat!(dimensions, curind_index) + end + end - push!(sizes, cur_dims) - push!(indices, cur_inds) - end - return sizes, indices + push!(sizes, cur_dims) + push!(indices, cur_inds) + end + return sizes, indices end @testset "random contractions" begin - MAX_CONTRACTED_INDICES = 10 - MAX_OPEN_INDICES = 5 - MAX_DIM = 5 - MAX_IND_PER_TENS = 3 - NUM_TESTS = 10 + MAX_CONTRACTED_INDICES = 10 + MAX_OPEN_INDICES = 5 + MAX_DIM = 5 + MAX_IND_PER_TENS = 3 + NUM_TESTS = 10 - for _ in 1:NUM_TESTS - sizes, indices = generate_random_network( - rand(1:MAX_CONTRACTED_INDICES), rand(1:MAX_OPEN_INDICES), MAX_DIM, MAX_IND_PER_TENS - ) - tensors = map(splat(randn), sizes) - result1 = ncon(tensors, indices) - result2 = ncon(tensors, indices; backend=Matricize()) - @test result1 ≈ result2 - end + for _ in 1:NUM_TESTS + sizes, indices = generate_random_network( + rand(1:MAX_CONTRACTED_INDICES), rand(1:MAX_OPEN_INDICES), MAX_DIM, MAX_IND_PER_TENS + ) + tensors = map(splat(randn), sizes) + result1 = ncon(tensors, indices) + result2 = ncon(tensors, indices; backend = Matricize()) + @test result1 ≈ result2 + end end