From c688aabf0ff599e830d196d01b0fe86b9204fc24 Mon Sep 17 00:00:00 2001 From: mtfishman <7855256+mtfishman@users.noreply.github.com> Date: Wed, 1 Oct 2025 01:08:24 +0000 Subject: [PATCH 1/2] Format .jl files (Runic) --- Project.toml | 2 +- docs/make.jl | 22 +- docs/make_index.jl | 16 +- docs/make_readme.jl | 16 +- examples/README.jl | 2 +- .../KroneckerArraysBlockSparseArraysExt.jl | 68 +- .../KroneckerArraysTensorAlgebraExt.jl | 34 +- .../KroneckerArraysTensorProductsExt.jl | 6 +- src/cartesianproduct.jl | 164 ++-- src/fillarrays.jl | 98 +-- src/kroneckerarray.jl | 580 ++++++------- src/linearalgebra.jl | 214 ++--- src/matrixalgebrakit.jl | 430 +++++----- test/runtests.jl | 80 +- test/test_aqua.jl | 2 +- test/test_basics.jl | 466 +++++----- test/test_blocksparsearrays.jl | 812 +++++++++--------- test/test_delta.jl | 762 ++++++++-------- test/test_matrixalgebrakit.jl | 238 ++--- test/test_matrixalgebrakit_delta.jl | 486 +++++------ test/test_tensoralgebra.jl | 8 +- test/test_tensorproducts.jl | 14 +- 22 files changed, 2261 insertions(+), 2259 deletions(-) diff --git a/Project.toml b/Project.toml index 5cad487..b50da7a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.2.3" +version = "0.2.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/make.jl b/docs/make.jl index ed07bce..1cdbd46 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,23 +2,23 @@ using KroneckerArrays: KroneckerArrays using Documenter: Documenter, DocMeta, deploydocs, makedocs DocMeta.setdocmeta!( - KroneckerArrays, :DocTestSetup, :(using KroneckerArrays); recursive=true + KroneckerArrays, :DocTestSetup, :(using KroneckerArrays); recursive = true ) include("make_index.jl") makedocs(; - modules=[KroneckerArrays], - authors="ITensor developers and contributors", - sitename="KroneckerArrays.jl", - format=Documenter.HTML(; - canonical="https://itensor.github.io/KroneckerArrays.jl", - edit_link="main", - assets=["assets/favicon.ico", "assets/extras.css"], - ), - pages=["Home" => "index.md", "Reference" => "reference.md"], + modules = [KroneckerArrays], + authors = "ITensor developers and contributors", + sitename = "KroneckerArrays.jl", + format = Documenter.HTML(; + canonical = "https://itensor.github.io/KroneckerArrays.jl", + edit_link = "main", + assets = ["assets/favicon.ico", "assets/extras.css"], + ), + pages = ["Home" => "index.md", "Reference" => "reference.md"], ) deploydocs(; - repo="github.com/ITensor/KroneckerArrays.jl", devbranch="main", push_preview=true + repo = "github.com/ITensor/KroneckerArrays.jl", devbranch = "main", push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index a6707dd..33ef14e 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -2,20 +2,20 @@ using Literate: Literate using KroneckerArrays: KroneckerArrays function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ ```@raw html Flatiron Center for Computational Quantum Physics logo. Flatiron Center for Computational Quantum Physics logo. ``` """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(KroneckerArrays), "examples", "README.jl"), - joinpath(pkgdir(KroneckerArrays), "docs", "src"); - flavor=Literate.DocumenterFlavor(), - name="index", - postprocess=ccq_logo, + joinpath(pkgdir(KroneckerArrays), "examples", "README.jl"), + joinpath(pkgdir(KroneckerArrays), "docs", "src"); + flavor = Literate.DocumenterFlavor(), + name = "index", + postprocess = ccq_logo, ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index 46bf81c..fb82379 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -2,20 +2,20 @@ using Literate: Literate using KroneckerArrays: KroneckerArrays function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ Flatiron Center for Computational Quantum Physics logo. """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(KroneckerArrays), "examples", "README.jl"), - joinpath(pkgdir(KroneckerArrays)); - flavor=Literate.CommonMarkFlavor(), - name="README", - postprocess=ccq_logo, + joinpath(pkgdir(KroneckerArrays), "examples", "README.jl"), + joinpath(pkgdir(KroneckerArrays)); + flavor = Literate.CommonMarkFlavor(), + name = "README", + postprocess = ccq_logo, ) diff --git a/examples/README.jl b/examples/README.jl index 1d8ca88..f8d397a 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -1,5 +1,5 @@ # # KroneckerArrays.jl -# +# # [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://itensor.github.io/KroneckerArrays.jl/stable/) # [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://itensor.github.io/KroneckerArrays.jl/dev/) # [![Build Status](https://github.com/ITensor/KroneckerArrays.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/ITensor/KroneckerArrays.jl/actions/workflows/Tests.yml?query=branch%3Amain) diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index 6638acd..136f8f1 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -4,23 +4,23 @@ using BlockArrays: Block using BlockSparseArrays: BlockIndexVector, GenericBlockIndex using KroneckerArrays: CartesianPair, CartesianProduct function Base.getindex( - b::Block, - I1::Union{CartesianPair,CartesianProduct}, - Irest::Union{CartesianPair,CartesianProduct}..., -) - return GenericBlockIndex(b, (I1, Irest...)) + b::Block, + I1::Union{CartesianPair, CartesianProduct}, + Irest::Union{CartesianPair, CartesianProduct}..., + ) + return GenericBlockIndex(b, (I1, Irest...)) end function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...) - return BlockIndexVector(b, (I1, Irest...)) + return BlockIndexVector(b, (I1, Irest...)) end using BlockSparseArrays: BlockSparseArrays, blockrange using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair}) - return blockrange(map(cartesianrange, bs)) + return blockrange(map(cartesianrange, bs)) end function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct}) - return blockrange(map(cartesianrange, bs)) + return blockrange(map(cartesianrange, bs)) end using BlockArrays: BlockArrays, mortar @@ -31,7 +31,7 @@ using KroneckerArrays: CartesianProductUnitRange # This is helpful when indexing `BlockUnitRange`, for example: # https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.7.1/src/blockaxis.jl#L540-L547 function BlockArrays.mortar(blocks::AbstractVector{<:CartesianProductUnitRange}) - return mortar(blocks, (blockrange(map(Base.axes1, blocks)),)) + return mortar(blocks, (blockrange(map(Base.axes1, blocks)),)) end using BlockArrays: AbstractBlockedUnitRange @@ -39,48 +39,48 @@ using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2, isactive function KroneckerArrays.arg1(r::AbstractBlockedUnitRange) - return mortar_axis(arg1.(eachblockaxis(r))) + return mortar_axis(arg1.(eachblockaxis(r))) end function KroneckerArrays.arg2(r::AbstractBlockedUnitRange) - return mortar_axis(arg2.(eachblockaxis(r))) + return mortar_axis(arg2.(eachblockaxis(r))) end function block_axes( - ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Vararg{Block{1},N} -) where {N} - return ntuple(N) do d - return only(axes(ax[d][I[d]])) - end + ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Vararg{Block{1}, N} + ) where {N} + return ntuple(N) do d + return only(axes(ax[d][I[d]])) + end end -function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) where {N} - return block_axes(ax, Tuple(I)...) +function block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Block{N}) where {N} + return block_axes(ax, Tuple(I)...) end using DiagonalArrays: ShapeInitializer ## TODO: Is this needed? function Base.getindex( - a::ZeroBlocks{N,KroneckerArray{T,N,A1,A2}}, I::Vararg{Int,N} -) where {T,N,A1<:AbstractArray{T,N},A2<:AbstractArray{T,N}} - ax_a1 = map(arg1, a.parentaxes) - ax_a2 = map(arg2, a.parentaxes) - block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I))) - block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I))) - # TODO: Is this a good definition? It is similar to - # the definition of `similar` and `adapt_structure`. - return if isactive(A1) == isactive(A2) - ZeroBlocks{N,A1}(ax_a1)[I...] ⊗ ZeroBlocks{N,A2}(ax_a2)[I...] - elseif isactive(A1) - ZeroBlocks{N,A1}(ax_a1)[I...] ⊗ A2(ShapeInitializer(), block_ax_a2) - elseif isactive(A2) - A1(ShapeInitializer(), block_ax_a1) ⊗ ZeroBlocks{N,A2}(ax_a2)[I...] - end + a::ZeroBlocks{N, KroneckerArray{T, N, A1, A2}}, I::Vararg{Int, N} + ) where {T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}} + ax_a1 = map(arg1, a.parentaxes) + ax_a2 = map(arg2, a.parentaxes) + block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I))) + block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I))) + # TODO: Is this a good definition? It is similar to + # the definition of `similar` and `adapt_structure`. + return if isactive(A1) == isactive(A2) + ZeroBlocks{N, A1}(ax_a1)[I...] ⊗ ZeroBlocks{N, A2}(ax_a2)[I...] + elseif isactive(A1) + ZeroBlocks{N, A1}(ax_a1)[I...] ⊗ A2(ShapeInitializer(), block_ax_a2) + elseif isactive(A2) + A1(ShapeInitializer(), block_ax_a1) ⊗ ZeroBlocks{N, A2}(ax_a2)[I...] + end end using BlockSparseArrays: BlockSparseArrays using KroneckerArrays: KroneckerArrays, KroneckerVector function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I) - return KroneckerArrays.to_truncated_indices(values, I) + return KroneckerArrays.to_truncated_indices(values, I) end end diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl index 2969ea5..78362c2 100644 --- a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -2,41 +2,41 @@ module KroneckerArraysTensorAlgebraExt using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2 using TensorAlgebra: - TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize + TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize -struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle - a::A - b::B +struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle + a::A + b::B end KroneckerArrays.arg1(style::KroneckerFusion) = style.a KroneckerArrays.arg2(style::KroneckerFusion) = style.b function TensorAlgebra.FusionStyle(a::KroneckerArray) - return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a))) + return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a))) end function matricize_kronecker( - style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} -) - return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm) + style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} + ) + return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm) end function TensorAlgebra.matricize( - style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} -) - return matricize_kronecker(style, a, biperm) + style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} + ) + return matricize_kronecker(style, a, biperm) end # Fix ambiguity error. # TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this. using TensorAlgebra: BlockedTrivialPermutation, unmatricize function TensorAlgebra.matricize( - style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} -) - return matricize_kronecker(style, a, biperm) + style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} + ) + return matricize_kronecker(style, a, biperm) end function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax) - return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗ - unmatricize(arg2(style), arg2(a), arg2.(ax)) + return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗ + unmatricize(arg2(style), arg2(a), arg2.(ax)) end function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax) - return unmatricize_kronecker(style, a, ax) + return unmatricize_kronecker(style, a, ax) end end diff --git a/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl b/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl index f45cc37..4920d92 100644 --- a/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl +++ b/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl @@ -3,9 +3,9 @@ module KroneckerArraysTensorProductsExt using KroneckerArrays: CartesianProductOneTo, ×, arg1, arg2, cartesianrange, unproduct using TensorProducts: TensorProducts, tensor_product function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo) - prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2)) - range = tensor_product(unproduct(a1), unproduct(a2)) - return cartesianrange(prod, range) + prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2)) + range = tensor_product(unproduct(a1), unproduct(a2)) + return cartesianrange(prod, range) end end diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index ace4871..1e5dd79 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -1,6 +1,6 @@ -struct CartesianPair{A1,A2} - arg1::A1 - arg2::A2 +struct CartesianPair{A1, A2} + arg1::A1 + arg2::A2 end arguments(a::CartesianPair) = (arg1(a), arg2(a)) arguments(a::CartesianPair, n::Int) = arguments(a)[n] @@ -11,14 +11,14 @@ arg2(a::CartesianPair) = getfield(a, :arg2) ×(a1, a2) = CartesianPair(a1, a2) function Base.show(io::IO, a::CartesianPair) - print(io, arg1(a), " × ", arg2(a)) - return nothing + print(io, arg1(a), " × ", arg2(a)) + return nothing end -struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <: - AbstractVector{CartesianPair{TA,TB}} - a::A - b::B +struct CartesianProduct{TA, TB, A <: AbstractVector{TA}, B <: AbstractVector{TB}} <: + AbstractVector{CartesianPair{TA, TB}} + a::A + b::B end arguments(a::CartesianProduct) = (arg1(a), arg2(a)) arguments(a::CartesianProduct, n::Int) = arguments(a)[n] @@ -29,12 +29,12 @@ arg2(a::CartesianProduct) = getfield(a, :b) Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a)) function Base.show(io::IO, a::CartesianProduct) - print(io, arg1(a), " × ", arg2(a)) - return nothing + print(io, arg1(a), " × ", arg2(a)) + return nothing end function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct) - show(io, a) - return nothing + show(io, a) + return nothing end ×(a1::AbstractVector, a2::AbstractVector) = CartesianProduct(a1, a2) @@ -42,52 +42,52 @@ Base.length(a::CartesianProduct) = length(arg1(a)) * length(arg2(a)) Base.size(a::CartesianProduct) = (length(a),) function Base.getindex(a::CartesianProduct, i::CartesianProduct) - return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] + return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] end function Base.getindex(a::CartesianProduct, i::CartesianPair) - return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] + return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] end function Base.getindex(a::CartesianProduct, i::Int) - I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i]) - return a[I[2] × I[1]] + I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i]) + return a[I[2] × I[1]] end -struct CartesianProductVector{T,P<:CartesianProduct,V<:AbstractVector{T}} <: - AbstractVector{T} - product::P - values::V +struct CartesianProductVector{T, P <: CartesianProduct, V <: AbstractVector{T}} <: + AbstractVector{T} + product::P + values::V end cartesianproduct(r::CartesianProductVector) = getfield(r, :product) unproduct(r::CartesianProductVector) = getfield(r, :values) Base.length(a::CartesianProductVector) = length(unproduct(a)) Base.size(a::CartesianProductVector) = (length(a),) function Base.axes(r::CartesianProductVector) - prod = cartesianproduct(r) - prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) - return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) + prod = cartesianproduct(r) + prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) + return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) end function Base.copy(a::CartesianProductVector) - return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a))) + return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a))) end function Base.getindex(r::CartesianProductVector, i::Integer) - return unproduct(r)[i] + return unproduct(r)[i] end function Base.show(io::IO, a::CartesianProductVector) - show(io, unproduct(a)) - return nothing + show(io, unproduct(a)) + return nothing end function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductVector) - show(io, mime, cartesianproduct(a)) - println(io) - show(io, mime, unproduct(a)) - return nothing + show(io, mime, cartesianproduct(a)) + println(io) + show(io, mime, unproduct(a)) + return nothing end -struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <: - AbstractUnitRange{T} - product::P - range::R +struct CartesianProductUnitRange{T, P <: CartesianProduct, R <: AbstractUnitRange{T}} <: + AbstractUnitRange{T} + product::P + range::R end Base.first(r::CartesianProductUnitRange) = first(r.range) Base.last(r::CartesianProductUnitRange) = last(r.range) @@ -99,98 +99,98 @@ arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a)) arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a)) function Base.getindex(a::CartesianProductUnitRange, i::CartesianProductUnitRange) - prod = cartesianproduct(a)[cartesianproduct(i)] - range = unproduct(a)[unproduct(i)] - return cartesianrange(prod, range) + prod = cartesianproduct(a)[cartesianproduct(i)] + range = unproduct(a)[unproduct(i)] + return cartesianrange(prod, range) end function Base.show(io::IO, a::CartesianProductUnitRange) - show(io, unproduct(a)) - return nothing + show(io, unproduct(a)) + return nothing end function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductUnitRange) - show(io, mime, cartesianproduct(a)) - println(io) - show(io, mime, unproduct(a)) - return nothing + show(io, mime, cartesianproduct(a)) + println(io) + show(io, mime, unproduct(a)) + return nothing end function CartesianProductUnitRange(p::CartesianProduct) - return CartesianProductUnitRange(p, Base.OneTo(length(p))) + return CartesianProductUnitRange(p, Base.OneTo(length(p))) end function CartesianProductUnitRange(a1, a2) - return CartesianProductUnitRange(a1 × a2) + return CartesianProductUnitRange(a1 × a2) end to_product_indices(a::AbstractVector) = a to_product_indices(i::Integer) = Base.OneTo(i) cartesianrange(a1, a2) = cartesianrange(to_product_indices(a1) × to_product_indices(a2)) function cartesianrange(p::CartesianPair) - p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) - return cartesianrange(p′) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) + return cartesianrange(p′) end function cartesianrange(p::CartesianProduct) - p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) - return cartesianrange(p′, Base.OneTo(length(p′))) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) + return cartesianrange(p′, Base.OneTo(length(p′))) end function cartesianrange(p::CartesianPair, range::AbstractUnitRange) - p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) - return cartesianrange(p′, range) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) + return cartesianrange(p′, range) end function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) - p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) - return CartesianProductUnitRange(p′, range) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) + return CartesianProductUnitRange(p′, range) end function Base.axes(r::CartesianProductUnitRange) - prod = cartesianproduct(r) - prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) - return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) + prod = cartesianproduct(r) + prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) + return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) end function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair) - return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) + return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) end -const CartesianProductOneTo{T,P<:CartesianProduct,R<:Base.OneTo{T}} = CartesianProductUnitRange{ - T,P,R +const CartesianProductOneTo{T, P <: CartesianProduct, R <: Base.OneTo{T}} = CartesianProductUnitRange{ + T, P, R, } Base.axes(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) Base.axes1(S::Base.Slice{<:CartesianProductOneTo}) = S.indices Base.unsafe_indices(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct) - prod = cartesianproduct(a) - prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)] - return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I)) + prod = cartesianproduct(a) + prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)] + return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I)) end # Reverse map from CartesianPair to linear index in the range. function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair) - i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds))) - return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]] + i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds))) + return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]] end using Base.Broadcast: DefaultArrayStyle for f in (:+, :-) - @eval begin - function Broadcast.broadcasted( - ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer - ) - return CartesianProductUnitRange(cartesianproduct(r), $f.(unproduct(r), x)) - end - function Broadcast.broadcasted( - ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange - ) - return CartesianProductUnitRange(cartesianproduct(r), $f.(x, unproduct(r))) + @eval begin + function Broadcast.broadcasted( + ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer + ) + return CartesianProductUnitRange(cartesianproduct(r), $f.(unproduct(r), x)) + end + function Broadcast.broadcasted( + ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange + ) + return CartesianProductUnitRange(cartesianproduct(r), $f.(x, unproduct(r))) + end end - end end using Base.Broadcast: axistype function Base.Broadcast.axistype( - r1::CartesianProductUnitRange, r2::CartesianProductUnitRange -) - prod = axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2)) - range = axistype(unproduct(r1), unproduct(r2)) - return cartesianrange(prod, range) + r1::CartesianProductUnitRange, r2::CartesianProductUnitRange + ) + prod = axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2)) + range = axistype(unproduct(r1), unproduct(r2)) + return cartesianrange(prod, range) end diff --git a/src/fillarrays.jl b/src/fillarrays.jl index 7a34083..0d8fd79 100644 --- a/src/fillarrays.jl +++ b/src/fillarrays.jl @@ -1,68 +1,68 @@ using FillArrays: FillArrays, Ones, Zeros function FillArrays.fillsimilar( - a::Zeros{T}, - ax::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) where {T} - return Zeros{T}(arg1.(ax)) ⊗ Zeros{T}(arg2.(ax)) + a::Zeros{T}, + ax::Tuple{ + CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, + }, + ) where {T} + return Zeros{T}(arg1.(ax)) ⊗ Zeros{T}(arg2.(ax)) end # Simplification rules similar to those for FillArrays.jl: # https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl using FillArrays: Zeros function Base.broadcasted( - style::KroneckerStyle, - ::typeof(+), - a::KroneckerArray, - b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -) - # TODO: Promote the element types. - return a + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray, + b::KroneckerArray{<:Any, <:Any, <:Zeros, <:Zeros}, + ) + # TODO: Promote the element types. + return a end function Base.broadcasted( - style::KroneckerStyle, - ::typeof(+), - a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, - b::KroneckerArray, -) - # TODO: Promote the element types. - return b + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray{<:Any, <:Any, <:Zeros, <:Zeros}, + b::KroneckerArray, + ) + # TODO: Promote the element types. + return b end function Base.broadcasted( - style::KroneckerStyle, - ::typeof(+), - a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, - b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -) - # TODO: Promote the element types and axes. - return b + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray{<:Any, <:Any, <:Zeros, <:Zeros}, + b::KroneckerArray{<:Any, <:Any, <:Zeros, <:Zeros}, + ) + # TODO: Promote the element types and axes. + return b end function Base.broadcasted( - style::KroneckerStyle, - ::typeof(-), - a::KroneckerArray, - b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -) - # TODO: Promote the element types. - return a + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray, + b::KroneckerArray{<:Any, <:Any, <:Zeros, <:Zeros}, + ) + # TODO: Promote the element types. + return a end function Base.broadcasted( - style::KroneckerStyle, - ::typeof(-), - a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, - b::KroneckerArray, -) - # TODO: Promote the element types. - # TODO: Return `broadcasted(-, b)`. - return -b + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray{<:Any, <:Any, <:Zeros, <:Zeros}, + b::KroneckerArray, + ) + # TODO: Promote the element types. + # TODO: Return `broadcasted(-, b)`. + return -b end function Base.broadcasted( - style::KroneckerStyle, - ::typeof(-), - a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, - b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -) - # TODO: Promote the element types and axes. - return b + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray{<:Any, <:Any, <:Zeros, <:Zeros}, + b::KroneckerArray{<:Any, <:Any, <:Zeros, <:Zeros}, + ) + # TODO: Promote the element types and axes. + return b end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index cb1f3fe..28328a7 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -1,13 +1,13 @@ function unwrap_array(a::AbstractArray) - p = parent(a) - p ≡ a && return a - return unwrap_array(p) + p = parent(a) + p ≡ a && return a + return unwrap_array(p) end isactive(a::AbstractArray) = ismutable(unwrap_array(a)) using TypeParameterAccessors: unwrap_array_type function isactive(arrayt::Type{<:AbstractArray}) - return ismutabletype(unwrap_array_type(arrayt)) + return ismutabletype(unwrap_array_type(arrayt)) end # Custom `_convert` works around the issue that @@ -16,80 +16,80 @@ end # https://github.com/JuliaLang/julia/pull/52487). # TODO: Delete once we drop support for Julia v1.10. function _convert(A::Type{<:AbstractArray}, a::AbstractArray) - return convert(A, a) + return convert(A, a) end using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag _construct(A::Type{<:Diagonal}, a::AbstractMatrix) = A(diag(a)) function _convert(A::Type{<:Diagonal}, a::AbstractMatrix) - LinearAlgebra.checksquare(a) - return isdiag(a) ? _construct(A, a) : throw(InexactError(:convert, A, a)) + LinearAlgebra.checksquare(a) + return isdiag(a) ? _construct(A, a) : throw(InexactError(:convert, A, a)) end -struct KroneckerArray{T,N,A1<:AbstractArray{T,N},A2<:AbstractArray{T,N}} <: - AbstractArray{T,N} - arg1::A1 - arg2::A2 +struct KroneckerArray{T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}} <: + AbstractArray{T, N} + arg1::A1 + arg2::A2 end function KroneckerArray(a1::AbstractArray, a2::AbstractArray) - if ndims(a1) != ndims(a2) - throw( - ArgumentError("Kronecker product requires arrays of the same number of dimensions.") - ) - end - elt = promote_type(eltype(a1), eltype(a2)) - return _convert(AbstractArray{elt}, a1) ⊗ _convert(AbstractArray{elt}, a2) + if ndims(a1) != ndims(a2) + throw( + ArgumentError("Kronecker product requires arrays of the same number of dimensions.") + ) + end + elt = promote_type(eltype(a1), eltype(a2)) + return _convert(AbstractArray{elt}, a1) ⊗ _convert(AbstractArray{elt}, a2) end -const KroneckerMatrix{T,A1<:AbstractMatrix{T},A2<:AbstractMatrix{T}} = KroneckerArray{ - T,2,A1,A2 +const KroneckerMatrix{T, A1 <: AbstractMatrix{T}, A2 <: AbstractMatrix{T}} = KroneckerArray{ + T, 2, A1, A2, } -const KroneckerVector{T,A1<:AbstractVector{T},A2<:AbstractVector{T}} = KroneckerArray{ - T,1,A1,A2 +const KroneckerVector{T, A1 <: AbstractVector{T}, A2 <: AbstractVector{T}} = KroneckerArray{ + T, 1, A1, A2, } @inline arg1(a::KroneckerArray) = getfield(a, :arg1) @inline arg2(a::KroneckerArray) = getfield(a, :arg2) function mutate_active_args!(f!, f, dest, src) - (isactive(arg1(dest)) || isactive(arg2(dest))) || - error("Can't mutate immutable KroneckerArray.") - if isactive(arg1(dest)) - f!(arg1(dest), arg1(src)) - else - arg1(dest) == f(arg1(src)) || error("Immutable arguments aren't equal.") - end - if isactive(arg2(dest)) - f!(arg2(dest), arg2(src)) - else - arg2(dest) == f(arg2(src)) || error("Immutable arguments aren't equal.") - end - return dest + (isactive(arg1(dest)) || isactive(arg2(dest))) || + error("Can't mutate immutable KroneckerArray.") + if isactive(arg1(dest)) + f!(arg1(dest), arg1(src)) + else + arg1(dest) == f(arg1(src)) || error("Immutable arguments aren't equal.") + end + if isactive(arg2(dest)) + f!(arg2(dest), arg2(src)) + else + arg2(dest) == f(arg2(src)) || error("Immutable arguments aren't equal.") + end + return dest end using Adapt: Adapt, adapt function Adapt.adapt_structure(to, a::KroneckerArray) - # TODO: Is this a good definition? It is similar to - # the definition of `similar`. - return if isactive(arg1(a)) == isactive(arg2(a)) - adapt(to, arg1(a)) ⊗ adapt(to, arg2(a)) - elseif isactive(arg1(a)) - adapt(to, arg1(a)) ⊗ arg2(a) - elseif isactive(arg2(a)) - arg1(a) ⊗ adapt(to, arg2(a)) - end + # TODO: Is this a good definition? It is similar to + # the definition of `similar`. + return if isactive(arg1(a)) == isactive(arg2(a)) + adapt(to, arg1(a)) ⊗ adapt(to, arg2(a)) + elseif isactive(arg1(a)) + adapt(to, arg1(a)) ⊗ arg2(a) + elseif isactive(arg2(a)) + arg1(a) ⊗ adapt(to, arg2(a)) + end end function Base.copy(a::KroneckerArray) - return copy(arg1(a)) ⊗ copy(arg2(a)) + return copy(arg1(a)) ⊗ copy(arg2(a)) end -function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}) where {N} - return mutate_active_args!(copyto!, copy, dest, src) +function Base.copyto!(dest::KroneckerArray{<:Any, N}, src::KroneckerArray{<:Any, N}) where {N} + return mutate_active_args!(copyto!, copy, dest, src) end function Base.convert( - ::Type{KroneckerArray{T,N,A1,A2}}, a::KroneckerArray -) where {T,N,A1,A2} - return _convert(A1, arg1(a)) ⊗ _convert(A2, arg2(a)) + ::Type{KroneckerArray{T, N, A1, A2}}, a::KroneckerArray + ) where {T, N, A1, A2} + return _convert(A1, arg1(a)) ⊗ _convert(A2, arg2(a)) end # Promote the element type if needed. @@ -98,111 +98,111 @@ end maybe_promot_eltype(a, elt) = eltype(a) <: elt ? a : elt.(a) function Base.similar( - a::KroneckerArray, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - # TODO: Is this a good definition? - return if isactive(arg1(a)) == isactive(arg2(a)) - similar(arg1(a), elt, arg1.(axs)) ⊗ similar(arg2(a), elt, arg2.(axs)) - elseif isactive(arg1(a)) - @assert arg2.(axs) == axes(arg2(a)) - similar(arg1(a), elt, arg1.(axs)) ⊗ maybe_promot_eltype(arg2(a), elt) - elseif isactive(arg2(a)) - @assert arg1.(axs) == axes(arg1(a)) - maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt, arg2.(axs)) - end + a::KroneckerArray, + elt::Type, + axs::Tuple{ + CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, + }, + ) + # TODO: Is this a good definition? + return if isactive(arg1(a)) == isactive(arg2(a)) + similar(arg1(a), elt, arg1.(axs)) ⊗ similar(arg2(a), elt, arg2.(axs)) + elseif isactive(arg1(a)) + @assert arg2.(axs) == axes(arg2(a)) + similar(arg1(a), elt, arg1.(axs)) ⊗ maybe_promot_eltype(arg2(a), elt) + elseif isactive(arg2(a)) + @assert arg1.(axs) == axes(arg1(a)) + maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt, arg2.(axs)) + end end function Base.similar(a::KroneckerArray, elt::Type) - # TODO: Is this a good definition? - return if isactive(arg1(a)) == isactive(arg2(a)) - similar(arg1(a), elt) ⊗ similar(arg2(a), elt) - elseif isactive(arg1(a)) - similar(arg1(a), elt) ⊗ maybe_promot_eltype(arg2(a), elt) - elseif isactive(arg2(a)) - maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt) - end + # TODO: Is this a good definition? + return if isactive(arg1(a)) == isactive(arg2(a)) + similar(arg1(a), elt) ⊗ similar(arg2(a), elt) + elseif isactive(arg1(a)) + similar(arg1(a), elt) ⊗ maybe_promot_eltype(arg2(a), elt) + elseif isactive(arg2(a)) + maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt) + end end function Base.similar(a::KroneckerArray) - # TODO: Is this a good definition? - return if isactive(arg1(a)) == isactive(arg2(a)) - similar(arg1(a)) ⊗ similar(arg2(a)) - elseif isactive(arg1(a)) - similar(arg1(a)) ⊗ arg2(a) - elseif isactive(arg2(a)) - arg1(a) ⊗ similar(arg2(a)) - end + # TODO: Is this a good definition? + return if isactive(arg1(a)) == isactive(arg2(a)) + similar(arg1(a)) ⊗ similar(arg2(a)) + elseif isactive(arg1(a)) + similar(arg1(a)) ⊗ arg2(a) + elseif isactive(arg2(a)) + arg1(a) ⊗ similar(arg2(a)) + end end function Base.similar( - a::AbstractArray, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - return similar(a, elt, map(arg1, axs)) ⊗ similar(a, elt, map(arg2, axs)) + a::AbstractArray, + elt::Type, + axs::Tuple{ + CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, + }, + ) + return similar(a, elt, map(arg1, axs)) ⊗ similar(a, elt, map(arg2, axs)) end function Base.similar( - arrayt::Type{<:KroneckerArray{<:Any,<:Any,A1,A2}}, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) where {A1,A2} - return similar(A1, map(arg1, axs)) ⊗ similar(A2, map(arg2, axs)) + arrayt::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}, + axs::Tuple{ + CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, + }, + ) where {A1, A2} + return similar(A1, map(arg1, axs)) ⊗ similar(A2, map(arg2, axs)) end function Base.similar( - ::Type{<:KroneckerArray{<:Any,<:Any,A1,A2}}, sz::Tuple{Int,Vararg{Int}} -) where {A1,A2} - return similar(promote_type(A1, A2), sz) + ::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}, sz::Tuple{Int, Vararg{Int}} + ) where {A1, A2} + return similar(promote_type(A1, A2), sz) end function Base.similar( - arrayt::Type{<:AbstractArray}, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - return similar(arrayt, map(arg1, axs)) ⊗ similar(arrayt, map(arg2, axs)) + arrayt::Type{<:AbstractArray}, + axs::Tuple{ + CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, + }, + ) + return similar(arrayt, map(arg1, axs)) ⊗ similar(arrayt, map(arg2, axs)) end function Base.permutedims(a::KroneckerArray, perm) - return permutedims(arg1(a), perm) ⊗ permutedims(arg2(a), perm) + return permutedims(arg1(a), perm) ⊗ permutedims(arg2(a), perm) end using DerivableInterfaces: DerivableInterfaces, permuteddims function DerivableInterfaces.permuteddims(a::KroneckerArray, perm) - return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm) + return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm) end function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm) - return mutate_active_args!( - (dest, src) -> permutedims!(dest, src, perm), Base.Fix2(permutedims, perm), dest, src - ) + return mutate_active_args!( + (dest, src) -> permutedims!(dest, src, perm), Base.Fix2(permutedims, perm), dest, src + ) end -function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}}) - return (t[1]..., flatten(Base.tail(t))...) +function flatten(t::Tuple{Tuple, Tuple, Vararg{Tuple}}) + return (t[1]..., flatten(Base.tail(t))...) end function flatten(t::Tuple{Tuple}) - return t[1] + return t[1] end flatten(::Tuple{}) = () function interleave(x::Tuple, y::Tuple) - length(x) == length(y) || throw(ArgumentError("Tuples must have the same length.")) - xy = ntuple(i -> (x[i], y[i]), length(x)) - return flatten(xy) + length(x) == length(y) || throw(ArgumentError("Tuples must have the same length.")) + xy = ntuple(i -> (x[i], y[i]), length(x)) + return flatten(xy) end # TODO: Maybe use scalar indexing based on KroneckerProducts.jl logic for cartesian indexing: # https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 -function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N} - a′ = reshape(a, interleave(size(a), ntuple(one, N))) - b′ = reshape(b, interleave(ntuple(one, N), size(b))) - c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N))) - sz = reverse(ntuple(i -> size(a, i) * size(b, i), N)) - return permutedims(reshape(c′, sz), reverse(ntuple(identity, N))) +function kron_nd(a::AbstractArray{<:Any, N}, b::AbstractArray{<:Any, N}) where {N} + a′ = reshape(a, interleave(size(a), ntuple(one, N))) + b′ = reshape(b, interleave(ntuple(one, N), size(b))) + c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N))) + sz = reverse(ntuple(i -> size(a, i) * size(b, i), N)) + return permutedims(reshape(c′, sz), reverse(ntuple(identity, N))) end kron_nd(a1::AbstractMatrix, a2::AbstractMatrix) = kron(a1, a2) kron_nd(a1::AbstractVector, a2::AbstractVector) = kron(a1, a2) @@ -211,58 +211,58 @@ kron_nd(a1::AbstractVector, a2::AbstractVector) = kron(a1, a2) Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a))) function Base.zero(a::KroneckerArray) - return if isactive(arg1(a)) == isactive(arg2(a)) - # TODO: Maybe this should zero both arguments? - # This is how `a * false` would behave. - arg1(a) ⊗ zero(arg2(a)) - elseif isactive(arg1(a)) - zero(arg1(a)) ⊗ arg2(a) - elseif isactive(arg2(a)) - arg1(a) ⊗ zero(arg2(a)) - end + return if isactive(arg1(a)) == isactive(arg2(a)) + # TODO: Maybe this should zero both arguments? + # This is how `a * false` would behave. + arg1(a) ⊗ zero(arg2(a)) + elseif isactive(arg1(a)) + zero(arg1(a)) ⊗ arg2(a) + elseif isactive(arg2(a)) + arg1(a) ⊗ zero(arg2(a)) + end end using DerivableInterfaces: DerivableInterfaces, zero! function DerivableInterfaces.zero!(a::KroneckerArray) - (isactive(arg1(a)) || isactive(arg2(a))) || - error("Can't mutate immutable KroneckerArray.") - isactive(arg1(a)) && zero!(arg1(a)) - isactive(arg2(a)) && zero!(arg2(a)) - return a + (isactive(arg1(a)) || isactive(arg2(a))) || + error("Can't mutate immutable KroneckerArray.") + isactive(arg1(a)) && zero!(arg1(a)) + isactive(arg2(a)) && zero!(arg2(a)) + return a end -function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} - return convert(Array{T,N}, collect(a)) +function Base.Array{T, N}(a::KroneckerArray{S, N}) where {T, S, N} + return convert(Array{T, N}, collect(a)) end function Base.size(a::KroneckerArray) - return ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a)) + return ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a)) end function Base.axes(a::KroneckerArray) - return ntuple(ndims(a)) do dim - return CartesianProductUnitRange( - axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim)) - ) - end + return ntuple(ndims(a)) do dim + return CartesianProductUnitRange( + axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim)) + ) + end end arguments(a::KroneckerArray) = (arg1(a), arg2(a)) arguments(a::KroneckerArray, n::Int) = arguments(a)[n] argument_types(a::KroneckerArray) = argument_types(typeof(a)) -argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A1,A2}}) where {A1,A2} = (A1, A2) +argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2) function Base.print_array(io::IO, a::KroneckerArray) - Base.print_array(io, arg1(a)) - println(io, "\n ⊗") - Base.print_array(io, arg2(a)) - return nothing + Base.print_array(io, arg1(a)) + println(io, "\n ⊗") + Base.print_array(io, arg2(a)) + return nothing end function Base.show(io::IO, a::KroneckerArray) - show(io, arg1(a)) - print(io, " ⊗ ") - show(io, arg2(a)) - return nothing + show(io, arg1(a)) + print(io, " ⊗ ") + show(io, arg2(a)) + return nothing end ⊗(a1::AbstractArray, a2::AbstractArray) = KroneckerArray(a1, a2) @@ -271,232 +271,232 @@ end ⊗(a1::AbstractArray, a2::Number) = a1 * a2 function Base.getindex(a::KroneckerArray, i::Integer) - return a[CartesianIndices(a)[i]] + return a[CartesianIndices(a)[i]] end using GPUArraysCore: GPUArraysCore -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} - GPUArraysCore.assertscalar("getindex") - I′ = ntuple(Val(N)) do dim - return cartesianproduct(axes(a, dim))[I[dim]] - end - return a[I′...] +function Base.getindex(a::KroneckerArray{<:Any, N}, I::Vararg{Integer, N}) where {N} + GPUArraysCore.assertscalar("getindex") + I′ = ntuple(Val(N)) do dim + return cartesianproduct(axes(a, dim))[I[dim]] + end + return a[I′...] end # Indexing logic. function Base.to_indices( - a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg} -) - I1 = to_indices(arg1(a), arg1.(inds), arg1.(I)) - I2 = to_indices(arg2(a), arg2.(inds), arg2.(I)) - return I1 .× I2 + a::KroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg} + ) + I1 = to_indices(arg1(a), arg1.(inds), arg1.(I)) + I2 = to_indices(arg2(a), arg2.(inds), arg2.(I)) + return I1 .× I2 end function Base.getindex( - a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N} -) where {N} - I′ = to_indices(a, I) - return arg1(a)[arg1.(I′)...] ⊗ arg2(a)[arg2.(I′)...] + a::KroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N} + ) where {N} + I′ = to_indices(a, I) + return arg1(a)[arg1.(I′)...] ⊗ arg2(a)[arg2.(I′)...] end # Fix ambigiuity error. -Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[] +Base.getindex(a::KroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[] arg1(::Colon) = (:) arg2(::Colon) = (:) arg1(::Base.Slice) = (:) arg2(::Base.Slice) = (:) function Base.view( - a::KroneckerArray{<:Any,N}, - I::Vararg{Union{CartesianProduct,CartesianProductUnitRange,Base.Slice,Colon},N}, -) where {N} - return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...) + a::KroneckerArray{<:Any, N}, + I::Vararg{Union{CartesianProduct, CartesianProductUnitRange, Base.Slice, Colon}, N}, + ) where {N} + return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...) end -function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} - return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...) +function Base.view(a::KroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N} + return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...) end # Fix ambigiuity error. -Base.view(a::KroneckerArray{<:Any,0}) = view(arg1(a)) ⊗ view(arg2(a)) +Base.view(a::KroneckerArray{<:Any, 0}) = view(arg1(a)) ⊗ view(arg2(a)) function Base.:(==)(a::KroneckerArray, b::KroneckerArray) - return arg1(a) == arg1(b) && arg2(a) == arg2(b) + return arg1(a) == arg1(b) && arg2(a) == arg2(b) end function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...) - return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...) + return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...) end function Base.iszero(a::KroneckerArray) - return iszero(arg1(a)) || iszero(arg2(a)) + return iszero(arg1(a)) || iszero(arg2(a)) end function Base.isreal(a::KroneckerArray) - return isreal(arg1(a)) && isreal(arg2(a)) + return isreal(arg1(a)) && isreal(arg2(a)) end using DiagonalArrays: DiagonalArrays, diagonal function DiagonalArrays.diagonal(a::KroneckerArray) - return diagonal(arg1(a)) ⊗ diagonal(arg2(a)) + return diagonal(arg1(a)) ⊗ diagonal(arg2(a)) end Base.real(a::KroneckerArray{<:Real}) = a function Base.real(a::KroneckerArray) - if iszero(imag(arg1(a))) || iszero(imag(arg2(a))) - return real(arg1(a)) ⊗ real(arg2(a)) - elseif iszero(real(arg1(a))) || iszero(real(arg2(a))) - return -(imag(arg1(a)) ⊗ imag(arg2(a))) - end - return real(arg1(a)) ⊗ real(arg2(a)) - imag(arg1(a)) ⊗ imag(arg2(a)) + if iszero(imag(arg1(a))) || iszero(imag(arg2(a))) + return real(arg1(a)) ⊗ real(arg2(a)) + elseif iszero(real(arg1(a))) || iszero(real(arg2(a))) + return -(imag(arg1(a)) ⊗ imag(arg2(a))) + end + return real(arg1(a)) ⊗ real(arg2(a)) - imag(arg1(a)) ⊗ imag(arg2(a)) end Base.imag(a::KroneckerArray{<:Real}) = zero(a) function Base.imag(a::KroneckerArray) - if iszero(imag(arg1(a))) || iszero(real(arg2(a))) - return real(arg1(a)) ⊗ imag(arg2(a)) - elseif iszero(real(arg1(a))) || iszero(imag(arg2(a))) - return imag(arg1(a)) ⊗ real(arg2(a)) - end - return real(arg1(a)) ⊗ imag(arg2(a)) + imag(arg1(a)) ⊗ real(arg2(a)) + if iszero(imag(arg1(a))) || iszero(real(arg2(a))) + return real(arg1(a)) ⊗ imag(arg2(a)) + elseif iszero(real(arg1(a))) || iszero(imag(arg2(a))) + return imag(arg1(a)) ⊗ real(arg2(a)) + end + return real(arg1(a)) ⊗ imag(arg2(a)) + imag(arg1(a)) ⊗ real(arg2(a)) end for f in [:transpose, :adjoint, :inv] - @eval begin - function Base.$f(a::KroneckerArray) - return $f(arg1(a)) ⊗ $f(arg2(a)) + @eval begin + function Base.$f(a::KroneckerArray) + return $f(arg1(a)) ⊗ $f(arg2(a)) + end end - end end function Base.reshape( - a::KroneckerArray, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} -) - return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax)) + a::KroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}} + ) + return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax)) end using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted -struct KroneckerStyle{N,A1,A2} <: AbstractArrayStyle{N} end -arg1(::Type{<:KroneckerStyle{<:Any,A1}}) where {A1} = A1 +struct KroneckerStyle{N, A1, A2} <: AbstractArrayStyle{N} end +arg1(::Type{<:KroneckerStyle{<:Any, A1}}) where {A1} = A1 arg1(style::KroneckerStyle) = arg1(typeof(style)) -arg2(::Type{<:KroneckerStyle{<:Any,<:Any,A2}}) where {A2} = A2 +arg2(::Type{<:KroneckerStyle{<:Any, <:Any, A2}}) where {A2} = A2 arg2(style::KroneckerStyle) = arg2(typeof(style)) function KroneckerStyle{N}(a1::BroadcastStyle, a2::BroadcastStyle) where {N} - return KroneckerStyle{N,a1,a2}() + return KroneckerStyle{N, a1, a2}() end function KroneckerStyle(a1::AbstractArrayStyle{N}, a2::AbstractArrayStyle{N}) where {N} - return KroneckerStyle{N}(a1, a2) + return KroneckerStyle{N}(a1, a2) end -function KroneckerStyle{N,A1,A2}(v::Val{M}) where {N,A1,A2,M} - return KroneckerStyle{M,typeof(A1)(v),typeof(A2)(v)}() +function KroneckerStyle{N, A1, A2}(v::Val{M}) where {N, A1, A2, M} + return KroneckerStyle{M, typeof(A1)(v), typeof(A2)(v)}() end -function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A1,A2}}) where {N,A1,A2} - return KroneckerStyle{N}(BroadcastStyle(A1), BroadcastStyle(A2)) +function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any, N, A1, A2}}) where {N, A1, A2} + return KroneckerStyle{N}(BroadcastStyle(A1), BroadcastStyle(A2)) end function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N} - style_a = BroadcastStyle(arg1(style1), arg1(style2)) - (style_a isa Broadcast.Unknown) && return Broadcast.Unknown() - style_b = BroadcastStyle(arg2(style1), arg2(style2)) - (style_b isa Broadcast.Unknown) && return Broadcast.Unknown() - return KroneckerStyle{N}(style_a, style_b) + style_a = BroadcastStyle(arg1(style1), arg1(style2)) + (style_a isa Broadcast.Unknown) && return Broadcast.Unknown() + style_b = BroadcastStyle(arg2(style1), arg2(style2)) + (style_b isa Broadcast.Unknown) && return Broadcast.Unknown() + return KroneckerStyle{N}(style_a, style_b) end function Base.similar( - bc::Broadcasted{<:KroneckerStyle{N,A1,A2}}, elt::Type, ax -) where {N,A1,A2} - bc_a = Broadcasted(A1, bc.f, arg1.(bc.args), arg1.(ax)) - bc_b = Broadcasted(A2, bc.f, arg2.(bc.args), arg2.(ax)) - a = similar(bc_a, elt) - b = similar(bc_b, elt) - return a ⊗ b + bc::Broadcasted{<:KroneckerStyle{N, A1, A2}}, elt::Type, ax + ) where {N, A1, A2} + bc_a = Broadcasted(A1, bc.f, arg1.(bc.args), arg1.(ax)) + bc_b = Broadcasted(A2, bc.f, arg2.(bc.args), arg2.(ax)) + a = similar(bc_a, elt) + b = similar(bc_b, elt) + return a ⊗ b end function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) - return Broadcast.broadcast_preserving_zero_d(f, a1, a_rest...) + return Broadcast.broadcast_preserving_zero_d(f, a1, a_rest...) end function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...) - dest .= f.(a1, a_rest...) - return dest + dest .= f.(a1, a_rest...) + return dest end using MapBroadcast: MapBroadcast, LinearCombination, Summed function KroneckerBroadcast(a::Summed{<:KroneckerStyle}) - f = LinearCombination(a) - args = MapBroadcast.arguments(a) - arg1s = arg1.(args) - arg2s = arg2.(args) - arg1_isunique = allequal(arg1s) - arg2_isunique = allequal(arg2s) - (arg1_isunique || arg2_isunique) || - error("This operation doesn't preserve the Kronecker structure.") - broadcast_arg = if arg1_isunique && arg2_isunique - isactive(first(arg1s)) ? 1 : 2 - elseif arg1_isunique - 2 - elseif arg2_isunique - 1 - end - return if broadcast_arg == 1 - broadcasted(f, arg1s...) ⊗ first(arg2s) - elseif broadcast_arg == 2 - first(arg1s) ⊗ broadcasted(f, arg2s...) - end + f = LinearCombination(a) + args = MapBroadcast.arguments(a) + arg1s = arg1.(args) + arg2s = arg2.(args) + arg1_isunique = allequal(arg1s) + arg2_isunique = allequal(arg2s) + (arg1_isunique || arg2_isunique) || + error("This operation doesn't preserve the Kronecker structure.") + broadcast_arg = if arg1_isunique && arg2_isunique + isactive(first(arg1s)) ? 1 : 2 + elseif arg1_isunique + 2 + elseif arg2_isunique + 1 + end + return if broadcast_arg == 1 + broadcasted(f, arg1s...) ⊗ first(arg2s) + elseif broadcast_arg == 2 + first(arg1s) ⊗ broadcasted(f, arg2s...) + end end function Base.copy(a::Summed{<:KroneckerStyle}) - return copy(KroneckerBroadcast(a)) + return copy(KroneckerBroadcast(a)) end function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle}) - return copyto!(dest, KroneckerBroadcast(a)) + return copyto!(dest, KroneckerBroadcast(a)) end function Broadcast.broadcasted(::KroneckerStyle, f, as...) - return error("Arbitrary broadcasting not supported for KroneckerArray.") + return error("Arbitrary broadcasting not supported for KroneckerArray.") end # Linear operations. function Broadcast.broadcasted(::KroneckerStyle, ::typeof(+), a1, a2) - return Summed(a1) + Summed(a2) + return Summed(a1) + Summed(a2) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a1, a2) - return Summed(a1) - Summed(a2) + return Summed(a1) - Summed(a2) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), c::Number, a) - return c * Summed(a) + return c * Summed(a) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a, c::Number) - return Summed(a) * c + return Summed(a) * c end # Fix ambiguity error. function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a::Number, b::Number) - return a * b + return a * b end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(/), a, c::Number) - return Summed(a) / c + return Summed(a) / c end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a) - return -Summed(a) + return -Summed(a) end # Rewrite rules to canonicalize broadcast expressions. -function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix1{typeof(*),<:Number}, a) - return broadcasted(style, *, f.x, a) +function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix1{typeof(*), <:Number}, a) + return broadcasted(style, *, f.x, a) end -function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(*),<:Number}, a) - return broadcasted(style, *, a, f.x) +function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(*), <:Number}, a) + return broadcasted(style, *, a, f.x) end -function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(/),<:Number}, a) - return broadcasted(style, /, a, f.x) +function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(/), <:Number}, a) + return broadcasted(style, /, a, f.x) end # Compatibility with MapBroadcast.jl. using MapBroadcast: MapBroadcast, MapFunction function Base.broadcasted( - style::KroneckerStyle, f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}}, a -) - return broadcasted(style, *, f.args[1], a) + style::KroneckerStyle, f::MapFunction{typeof(*), <:Tuple{<:Number, MapBroadcast.Arg}}, a + ) + return broadcasted(style, *, f.args[1], a) end function Base.broadcasted( - style::KroneckerStyle, f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}}, a -) - return broadcasted(style, *, a, f.args[2]) + style::KroneckerStyle, f::MapFunction{typeof(*), <:Tuple{MapBroadcast.Arg, <:Number}}, a + ) + return broadcasted(style, *, a, f.args[2]) end function Base.broadcasted( - style::KroneckerStyle, f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}}, a -) - return broadcasted(style, /, a, f.args[2]) + style::KroneckerStyle, f::MapFunction{typeof(/), <:Tuple{MapBroadcast.Arg, <:Number}}, a + ) + return broadcasted(style, /, a, f.args[2]) end # Use to determine the element type of KroneckerBroadcasted. _eltype(x) = eltype(x) @@ -506,9 +506,9 @@ using Base.Broadcast: broadcasted # Represents broadcast operations that can be applied Kronecker-wise, # i.e. independently to each argument of the Kronecker product. # Note that not all broadcast operations can be mapped to this. -struct KroneckerBroadcasted{A1,A2} - arg1::A1 - arg2::A2 +struct KroneckerBroadcasted{A1, A2} + arg1::A1 + arg2::A2 end @inline arg1(a::KroneckerBroadcasted) = getfield(a, :arg1) @inline arg2(a::KroneckerBroadcasted) = getfield(a, :arg2) @@ -520,34 +520,34 @@ Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a) Broadcast.broadcastable(a::KroneckerBroadcasted) = a Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) ⊗ copy(arg2(a)) function Base.copyto!(dest::KroneckerArray, src::KroneckerBroadcasted) - return mutate_active_args!(copyto!, copy, dest, src) + return mutate_active_args!(copyto!, copy, dest, src) end function Base.eltype(a::KroneckerBroadcasted) - a1 = arg1(a) - a2 = arg2(a) - return Base.promote_op(*, _eltype(a1), _eltype(a2)) + a1 = arg1(a) + a2 = arg2(a) + return Base.promote_op(*, _eltype(a1), _eltype(a2)) end function Base.axes(a::KroneckerBroadcasted) - ax1 = axes(arg1(a)) - ax2 = axes(arg2(a)) - return cartesianrange.(ax1 .× ax2) + ax1 = axes(arg1(a)) + ax2 = axes(arg2(a)) + return cartesianrange.(ax1 .× ax2) end function Base.BroadcastStyle( - ::Type{<:KroneckerBroadcasted{A1,A2}} -) where {StyleA1,StyleA2,A1<:Broadcasted{StyleA1},A2<:Broadcasted{StyleA2}} - @assert ndims(A1) == ndims(A2) - N = ndims(A1) - return KroneckerStyle{N}(StyleA1(), StyleA2()) + ::Type{<:KroneckerBroadcasted{A1, A2}} + ) where {StyleA1, StyleA2, A1 <: Broadcasted{StyleA1}, A2 <: Broadcasted{StyleA2}} + @assert ndims(A1) == ndims(A2) + N = ndims(A1) + return KroneckerStyle{N}(StyleA1(), StyleA2()) end # Operations that preserve the Kronecker structure. for f in [:identity, :conj] - @eval begin - function Broadcast.broadcasted( - ::KroneckerStyle{<:Any,A1,A2}, ::typeof($f), a - ) where {A1,A2} - return broadcasted(A1, $f, arg1(a)) ⊗ broadcasted(A2, $f, arg2(a)) + @eval begin + function Broadcast.broadcasted( + ::KroneckerStyle{<:Any, A1, A2}, ::typeof($f), a + ) where {A1, A2} + return broadcasted(A1, $f, arg1(a)) ⊗ broadcasted(A2, $f, arg2(a)) + end end - end end diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index b8c0aaa..de67466 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -1,114 +1,114 @@ using DiagonalArrays: δ using LinearAlgebra: - LinearAlgebra, - Diagonal, - Eigen, - SVD, - det, - diag, - eigen, - eigvals, - lq, - mul!, - norm, - qr, - svd, - svdvals, - tr + LinearAlgebra, + Diagonal, + Eigen, + SVD, + det, + diag, + eigen, + eigvals, + lq, + mul!, + norm, + qr, + svd, + svdvals, + tr using LinearAlgebra: LinearAlgebra function KroneckerArray(J::LinearAlgebra.UniformScaling, ax::Tuple) - return δ(eltype(J), arg1.(ax)) ⊗ δ(eltype(J), arg2.(ax)) + return δ(eltype(J), arg1.(ax)) ⊗ δ(eltype(J), arg2.(ax)) end function Base.copyto!(a::KroneckerArray, J::LinearAlgebra.UniformScaling) - copyto!(a, KroneckerArray(J, axes(a))) - return a + copyto!(a, KroneckerArray(J, axes(a))) + return a end using LinearAlgebra: LinearAlgebra, pinv function LinearAlgebra.pinv(a::KroneckerArray; kwargs...) - return pinv(arg1(a); kwargs...) ⊗ pinv(arg2(a); kwargs...) + return pinv(arg1(a); kwargs...) ⊗ pinv(arg2(a); kwargs...) end function LinearAlgebra.diag(a::KroneckerArray) - return copy(DiagonalArrays.diagview(a)) + return copy(DiagonalArrays.diagview(a)) end function Base.:*(a::KroneckerArray, b::KroneckerArray) - return (arg1(a) * arg1(b)) ⊗ (arg2(a) * arg2(b)) + return (arg1(a) * arg1(b)) ⊗ (arg2(a) * arg2(b)) end function LinearAlgebra.mul!( - c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number -) - iszero(β) || - iszero(c) || - throw( - ArgumentError( - "Can't multiple KroneckerArrays with nonzero β and nonzero destination." - ), + c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number ) - # TODO: Only perform in-place operation on the non-active argument(s). - mul!(arg1(c), arg1(a), arg1(b)) - mul!(arg2(c), arg2(a), arg2(b), α, β) - return c + iszero(β) || + iszero(c) || + throw( + ArgumentError( + "Can't multiple KroneckerArrays with nonzero β and nonzero destination." + ), + ) + # TODO: Only perform in-place operation on the non-active argument(s). + mul!(arg1(c), arg1(a), arg1(b)) + mul!(arg2(c), arg2(a), arg2(b), α, β) + return c end using LinearAlgebra: tr function LinearAlgebra.tr(a::KroneckerArray) - return tr(arg1(a)) * tr(arg2(a)) + return tr(arg1(a)) * tr(arg2(a)) end using LinearAlgebra: norm -function LinearAlgebra.norm(a::KroneckerArray, p::Int=2) - return norm(arg1(a), p) * norm(arg2(a), p) +function LinearAlgebra.norm(a::KroneckerArray, p::Int = 2) + return norm(arg1(a), p) * norm(arg2(a), p) end # Matrix functions const MATRIX_FUNCTIONS = [ - :exp, - :cis, - :log, - :sqrt, - :cbrt, - :cos, - :sin, - :tan, - :csc, - :sec, - :cot, - :cosh, - :sinh, - :tanh, - :csch, - :sech, - :coth, - :acos, - :asin, - :atan, - :acsc, - :asec, - :acot, - :acosh, - :asinh, - :atanh, - :acsch, - :asech, - :acoth, + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, ] for f in MATRIX_FUNCTIONS - @eval begin - function Base.$f(a::KroneckerArray) - return if isone(arg1(a)) - arg1(a) ⊗ $f(arg2(a)) - elseif isone(arg2(a)) - $f(arg1(a)) ⊗ arg2(a) - else - throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported.")) - end + @eval begin + function Base.$f(a::KroneckerArray) + return if isone(arg1(a)) + arg1(a) ⊗ $f(arg2(a)) + elseif isone(arg2(a)) + $f(arg1(a)) ⊗ arg2(a) + else + throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported.")) + end + end end - end end # `DiagonalArrays.issquare` and `DiagonalArrays.checksquare` are more general @@ -116,79 +116,79 @@ end # that the codomain and domain are dual of each other. using DiagonalArrays: DiagonalArrays, checksquare, issquare function DiagonalArrays.issquare(a::KroneckerArray) - return issquare(arg1(a)) && issquare(arg2(a)) + return issquare(arg1(a)) && issquare(arg2(a)) end using LinearAlgebra: det function LinearAlgebra.det(a::KroneckerArray) - checksquare(a) - return det(arg1(a)) ^ size(arg2(a), 1) * det(arg2(a)) ^ size(arg1(a), 1) + checksquare(a) + return det(arg1(a))^size(arg2(a), 1) * det(arg2(a))^size(arg1(a), 1) end function LinearAlgebra.svd(a::KroneckerArray) - F1 = svd(arg1(a)) - F2 = svd(arg2(a)) - return SVD(F1.U ⊗ F2.U, F1.S ⊗ F2.S, F1.Vt ⊗ F2.Vt) + F1 = svd(arg1(a)) + F2 = svd(arg2(a)) + return SVD(F1.U ⊗ F2.U, F1.S ⊗ F2.S, F1.Vt ⊗ F2.Vt) end function LinearAlgebra.svdvals(a::KroneckerArray) - return svdvals(arg1(a)) ⊗ svdvals(arg2(a)) + return svdvals(arg1(a)) ⊗ svdvals(arg2(a)) end function LinearAlgebra.eigen(a::KroneckerArray) - F1 = eigen(arg1(a)) - F2 = eigen(arg2(a)) - return Eigen(F1.values ⊗ F2.values, F1.vectors ⊗ F2.vectors) + F1 = eigen(arg1(a)) + F2 = eigen(arg2(a)) + return Eigen(F1.values ⊗ F2.values, F1.vectors ⊗ F2.vectors) end function LinearAlgebra.eigvals(a::KroneckerArray) - return eigvals(arg1(a)) ⊗ eigvals(arg2(a)) + return eigvals(arg1(a)) ⊗ eigvals(arg2(a)) end -struct KroneckerQ{A1,A2} - arg1::A1 - arg2::A2 +struct KroneckerQ{A1, A2} + arg1::A1 + arg2::A2 end @inline arg1(a::KroneckerQ) = getfield(a, :arg1) @inline arg2(a::KroneckerQ) = getfield(a, :arg2) function Base.:*(a::KroneckerQ, b::KroneckerQ) - return (arg1(a) * arg1(b)) ⊗ (arg2(a) * arg2(b)) + return (arg1(a) * arg1(b)) ⊗ (arg2(a) * arg2(b)) end function Base.:*(a1::KroneckerQ, a2::KroneckerArray) - return (arg1(a1) * arg1(a2)) ⊗ (arg2(a1) * arg2(a2)) + return (arg1(a1) * arg1(a2)) ⊗ (arg2(a1) * arg2(a2)) end function Base.:*(a1::KroneckerArray, a2::KroneckerQ) - return (arg1(a1) * arg1(a2)) ⊗ (arg2(a1) * arg2(a2)) + return (arg1(a1) * arg1(a2)) ⊗ (arg2(a1) * arg2(a2)) end function Base.adjoint(a::KroneckerQ) - return KroneckerQ(arg1(a)', arg2(a)') + return KroneckerQ(arg1(a)', arg2(a)') end -struct KroneckerQR{QQ,RR} - Q::QQ - R::RR +struct KroneckerQR{QQ, RR} + Q::QQ + R::RR end Base.iterate(F::KroneckerQR) = (F.Q, Val(:R)) Base.iterate(F::KroneckerQR, ::Val{:R}) = (F.R, Val(:done)) Base.iterate(F::KroneckerQR, ::Val{:done}) = nothing function ⊗(a1::LinearAlgebra.QRCompactWYQ, a2::LinearAlgebra.QRCompactWYQ) - return KroneckerQ(a1, a2) + return KroneckerQ(a1, a2) end function LinearAlgebra.qr(a::KroneckerArray) - Fa = qr(arg1(a)) - Fb = qr(arg2(a)) - return KroneckerQR(Fa.Q ⊗ Fb.Q, Fa.R ⊗ Fb.R) + Fa = qr(arg1(a)) + Fb = qr(arg2(a)) + return KroneckerQR(Fa.Q ⊗ Fb.Q, Fa.R ⊗ Fb.R) end -struct KroneckerLQ{LL,QQ} - L::LL - Q::QQ +struct KroneckerLQ{LL, QQ} + L::LL + Q::QQ end Base.iterate(F::KroneckerLQ) = (F.L, Val(:Q)) Base.iterate(F::KroneckerLQ, ::Val{:Q}) = (F.Q, Val(:done)) Base.iterate(F::KroneckerLQ, ::Val{:done}) = nothing function ⊗(a1::LinearAlgebra.LQPackedQ, a2::LinearAlgebra.LQPackedQ) - return KroneckerQ(a1, a2) + return KroneckerQ(a1, a2) end function LinearAlgebra.lq(a::KroneckerArray) - Fa = lq(arg1(a)) - Fb = lq(arg2(a)) - return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q) + Fa = lq(arg1(a)) + Fb = lq(arg2(a)) + return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q) end diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index 84eaa26..5bbca66 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -1,285 +1,285 @@ using MatrixAlgebraKit: - MatrixAlgebraKit, - AbstractAlgorithm, - TruncationStrategy, - default_eig_algorithm, - default_eigh_algorithm, - default_lq_algorithm, - default_polar_algorithm, - default_qr_algorithm, - default_svd_algorithm, - eig_full!, - eig_full, - eig_trunc!, - eig_trunc, - eig_vals!, - eig_vals, - eigh_full!, - eigh_full, - eigh_trunc!, - eigh_trunc, - eigh_vals!, - eigh_vals, - initialize_output, - left_null!, - left_null, - left_orth!, - left_orth, - left_polar!, - left_polar, - lq_compact!, - lq_compact, - lq_full!, - lq_full, - qr_compact!, - qr_compact, - qr_full!, - qr_full, - right_null!, - right_null, - right_orth!, - right_orth, - right_polar!, - right_polar, - svd_compact!, - svd_compact, - svd_full!, - svd_full, - svd_trunc!, - svd_trunc, - svd_vals!, - svd_vals, - truncate! + MatrixAlgebraKit, + AbstractAlgorithm, + TruncationStrategy, + default_eig_algorithm, + default_eigh_algorithm, + default_lq_algorithm, + default_polar_algorithm, + default_qr_algorithm, + default_svd_algorithm, + eig_full!, + eig_full, + eig_trunc!, + eig_trunc, + eig_vals!, + eig_vals, + eigh_full!, + eigh_full, + eigh_trunc!, + eigh_trunc, + eigh_vals!, + eigh_vals, + initialize_output, + left_null!, + left_null, + left_orth!, + left_orth, + left_polar!, + left_polar, + lq_compact!, + lq_compact, + lq_full!, + lq_full, + qr_compact!, + qr_compact, + qr_full!, + qr_full, + right_null!, + right_null, + right_orth!, + right_orth, + right_polar!, + right_polar, + svd_compact!, + svd_compact, + svd_full!, + svd_full, + svd_trunc!, + svd_trunc, + svd_vals!, + svd_vals, + truncate! using DiagonalArrays: DiagonalArrays, diagview function DiagonalArrays.diagview(a::KroneckerMatrix) - return diagview(arg1(a)) ⊗ diagview(arg2(a)) + return diagview(arg1(a)) ⊗ diagview(arg2(a)) end MatrixAlgebraKit.diagview(a::KroneckerMatrix) = diagview(a) -struct KroneckerAlgorithm{A1,A2} <: AbstractAlgorithm - arg1::A1 - arg2::A2 +struct KroneckerAlgorithm{A1, A2} <: AbstractAlgorithm + arg1::A1 + arg2::A2 end @inline arg1(alg::KroneckerAlgorithm) = getfield(alg, :arg1) @inline arg2(alg::KroneckerAlgorithm) = getfield(alg, :arg2) using MatrixAlgebraKit: - copy_input, - eig_full, - eig_vals, - eigh_full, - eigh_vals, - qr_compact, - qr_full, - left_null, - left_orth, - left_polar, - lq_compact, - lq_full, - right_null, - right_orth, - right_polar, - svd_compact, - svd_full + copy_input, + eig_full, + eig_vals, + eigh_full, + eigh_vals, + qr_compact, + qr_full, + left_null, + left_orth, + left_polar, + lq_compact, + lq_full, + right_null, + right_orth, + right_polar, + svd_compact, + svd_full for f in [ - :eig_full, - :eigh_full, - :qr_compact, - :qr_full, - :left_polar, - :lq_compact, - :lq_full, - :right_polar, - :svd_compact, - :svd_full, -] - @eval begin - function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix) - return copy_input($f, arg1(a)) ⊗ copy_input($f, arg2(a)) + :eig_full, + :eigh_full, + :qr_compact, + :qr_full, + :left_polar, + :lq_compact, + :lq_full, + :right_polar, + :svd_compact, + :svd_full, + ] + @eval begin + function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix) + return copy_input($f, arg1(a)) ⊗ copy_input($f, arg2(a)) + end end - end end for f in [ - :default_eig_algorithm, - :default_eigh_algorithm, - :default_lq_algorithm, - :default_qr_algorithm, - :default_polar_algorithm, - :default_svd_algorithm, -] - @eval begin - function MatrixAlgebraKit.$f( - A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs... - ) - A1, A2 = argument_types(A) - return KroneckerAlgorithm( - $f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...) - ) + :default_eig_algorithm, + :default_eigh_algorithm, + :default_lq_algorithm, + :default_qr_algorithm, + :default_polar_algorithm, + :default_svd_algorithm, + ] + @eval begin + function MatrixAlgebraKit.$f( + A::Type{<:KroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs... + ) + A1, A2 = argument_types(A) + return KroneckerAlgorithm( + $f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...) + ) + end end - end end for f in [ - :eig_full, - :eigh_full, - :left_polar, - :lq_compact, - :lq_full, - :qr_compact, - :qr_full, - :right_polar, - :svd_compact, - :svd_full, -] - f! = Symbol(f, :!) - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm - ) - return nothing + :eig_full, + :eigh_full, + :left_polar, + :lq_compact, + :lq_full, + :qr_compact, + :qr_full, + :right_polar, + :svd_compact, + :svd_full, + ] + f! = Symbol(f, :!) + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm + ) + return nothing + end + function MatrixAlgebraKit.$f!( + a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs... + ) + a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...) + a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...) + return a1 .⊗ a2 + end end - function MatrixAlgebraKit.$f!( - a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...) - a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...) - return a1 .⊗ a2 - end - end end for f in [:eig_vals, :eigh_vals, :svd_vals] - f! = Symbol(f, :!) - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm - ) - return nothing + f! = Symbol(f, :!) + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm + ) + return nothing + end + function MatrixAlgebraKit.$f!( + a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs... + ) + a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...) + a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...) + return a1 ⊗ a2 + end end - function MatrixAlgebraKit.$f!( - a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...) - a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...) - return a1 ⊗ a2 - end - end end for f in [:left_orth, :right_orth] - f! = Symbol(f, :!) - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f!), a::KroneckerMatrix) - return nothing + f! = Symbol(f, :!) + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f!), a::KroneckerMatrix) + return nothing + end + function MatrixAlgebraKit.$f!( + a::KroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs... + ) + a1 = $f(arg1(a); kwargs..., kwargs1...) + a2 = $f(arg2(a); kwargs..., kwargs2...) + return a1 .⊗ a2 + end end - function MatrixAlgebraKit.$f!( - a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs... - ) - a1 = $f(arg1(a); kwargs..., kwargs1...) - a2 = $f(arg2(a); kwargs..., kwargs2...) - return a1 .⊗ a2 - end - end end for f in [:left_null, :right_null] - f! = Symbol(f, :!) - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) - return nothing + f! = Symbol(f, :!) + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) + return nothing + end + function MatrixAlgebraKit.$f!( + a::KroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs... + ) + a1 = $f(arg1(a); kwargs..., kwargs1...) + a2 = $f(arg2(a); kwargs..., kwargs2...) + return a1 ⊗ a2 + end end - function MatrixAlgebraKit.$f!( - a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs... - ) - a1 = $f(arg1(a); kwargs..., kwargs1...) - a2 = $f(arg2(a); kwargs..., kwargs2...) - return a1 ⊗ a2 - end - end end # Truncation using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate! -struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy - strategy::T +struct KroneckerTruncationStrategy{T <: TruncationStrategy} <: TruncationStrategy + strategy::T end using FillArrays: OnesVector -const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} -const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} -const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} +const OnesKroneckerVector{T, A <: OnesVector{T}, B <: AbstractVector{T}} = KroneckerVector{T, A, B} +const KroneckerOnesVector{T, A <: AbstractVector{T}, B <: OnesVector{T}} = KroneckerVector{T, A, B} +const OnesVectorOnesVector{T, A <: OnesVector{T}, B <: OnesVector{T}} = KroneckerVector{T, A, B} axis(a) = only(axes(a)) # Convert indices determined with a generic call to `findtruncated` to indices # more suited for a KroneckerVector. function to_truncated_indices(values::OnesKroneckerVector, I) - prods = cartesianproduct(axis(values))[I] - I_id = only(to_indices(arg1(values), (:,))) - I_data = unique(arg2.(prods)) - # Drop truncations that occur within the identity. - I_data = filter(I_data) do i - return count(x -> arg2(x) == i, prods) == length(arg2(values)) - end - return I_id × I_data + prods = cartesianproduct(axis(values))[I] + I_id = only(to_indices(arg1(values), (:,))) + I_data = unique(arg2.(prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> arg2(x) == i, prods) == length(arg2(values)) + end + return I_id × I_data end function to_truncated_indices(values::KroneckerOnesVector, I) - #I = findtruncated(Vector(values), strategy.strategy) - prods = cartesianproduct(axis(values))[I] - I_data = unique(arg1.(prods)) - # Drop truncations that occur within the identity. - I_data = filter(I_data) do i - return count(x -> arg1(x) == i, prods) == length(arg2(values)) - end - I_id = only(to_indices(arg2(values), (:,))) - return I_data × I_id + #I = findtruncated(Vector(values), strategy.strategy) + prods = cartesianproduct(axis(values))[I] + I_data = unique(arg1.(prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> arg1(x) == i, prods) == length(arg2(values)) + end + I_id = only(to_indices(arg2(values), (:,))) + return I_data × I_id end # Fix ambiguity error. function to_truncated_indices(values::OnesVectorOnesVector, I) - return throw(ArgumentError("Not implemented")) + return throw(ArgumentError("Not implemented")) end function to_truncated_indices(values::KroneckerVector, I) - return throw(ArgumentError("Not implemented")) + return throw(ArgumentError("Not implemented")) end function MatrixAlgebraKit.findtruncated( - values::KroneckerVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - return to_truncated_indices(values, I) + values::KroneckerVector, strategy::KroneckerTruncationStrategy + ) + I = findtruncated(Vector(values), strategy.strategy) + return to_truncated_indices(values, I) end for f in [:eig_trunc!, :eigh_trunc!] - @eval begin - function MatrixAlgebraKit.truncate!( - ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy - ) - return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), DV::NTuple{2, KroneckerMatrix}, strategy::TruncationStrategy + ) + return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) + end + function MatrixAlgebraKit.truncate!( + ::typeof($f), (D, V)::NTuple{2, KroneckerMatrix}, strategy::KroneckerTruncationStrategy + ) + I = findtruncated(diagview(D), strategy) + return (D[I, I], V[(:) × (:), I]) + end end - function MatrixAlgebraKit.truncate!( - ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy - ) - I = findtruncated(diagview(D), strategy) - return (D[I, I], V[(:) × (:), I]) - end - end end function MatrixAlgebraKit.truncate!( - f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy -) - return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) + f::typeof(svd_trunc!), USVᴴ::NTuple{3, KroneckerMatrix}, strategy::TruncationStrategy + ) + return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) end function MatrixAlgebraKit.truncate!( - ::typeof(svd_trunc!), - (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, - strategy::KroneckerTruncationStrategy, -) - I = findtruncated(diagview(S), strategy) - return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) + ::typeof(svd_trunc!), + (U, S, Vᴴ)::NTuple{3, KroneckerMatrix}, + strategy::KroneckerTruncationStrategy, + ) + I = findtruncated(diagview(S), strategy) + return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) end diff --git a/test/runtests.jl b/test/runtests.jl index 98b2d2b..0008050 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,60 +6,62 @@ using Suppressor: Suppressor const pat = r"(?:--group=)(\w+)" arg_id = findfirst(contains(pat), ARGS) const GROUP = uppercase( - if isnothing(arg_id) - get(ENV, "GROUP", "ALL") - else - only(match(pat, ARGS[arg_id]).captures) - end, + if isnothing(arg_id) + get(ENV, "GROUP", "ALL") + else + only(match(pat, ARGS[arg_id]).captures) + end, ) "match files of the form `test_*.jl`, but exclude `*setup*.jl`" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") end "match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" function isexamplefile(fn) - return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") + return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @time begin - # tests in groups based on folder structure - for testgroup in filter(isdir, readdir(@__DIR__)) - if GROUP == "ALL" || GROUP == uppercase(testgroup) - groupdir = joinpath(@__DIR__, testgroup) - for file in filter(istestfile, readdir(groupdir)) - filename = joinpath(groupdir, file) - @eval @safetestset $file begin - include($filename) + # tests in groups based on folder structure + for testgroup in filter(isdir, readdir(@__DIR__)) + if GROUP == "ALL" || GROUP == uppercase(testgroup) + groupdir = joinpath(@__DIR__, testgroup) + for file in filter(istestfile, readdir(groupdir)) + filename = joinpath(groupdir, file) + @eval @safetestset $file begin + include($filename) + end + end end - end end - end - # single files in top folder - for file in filter(istestfile, readdir(@__DIR__)) - (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion - @eval @safetestset $file begin - include($file) + # single files in top folder + for file in filter(istestfile, readdir(@__DIR__)) + (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion + @eval @safetestset $file begin + include($file) + end end - end - # test examples - examplepath = joinpath(@__DIR__, "..", "examples") - for (root, _, files) in walkdir(examplepath) - contains(chopprefix(root, @__DIR__), "setup") && continue - for file in filter(isexamplefile, files) - filename = joinpath(root, file) - @eval begin - @safetestset $file begin - $(Expr( - :macrocall, - GlobalRef(Suppressor, Symbol("@suppress")), - LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), - )) + # test examples + examplepath = joinpath(@__DIR__, "..", "examples") + for (root, _, files) in walkdir(examplepath) + contains(chopprefix(root, @__DIR__), "setup") && continue + for file in filter(isexamplefile, files) + filename = joinpath(root, file) + @eval begin + @safetestset $file begin + $( + Expr( + :macrocall, + GlobalRef(Suppressor, Symbol("@suppress")), + LineNumberNode(@__LINE__, @__FILE__), + :(include($filename)), + ) + ) + end + end end - end end - end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 5727e26..f1486ae 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(KroneckerArrays) + Aqua.test_all(KroneckerArrays) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 5bfcf65..4b8c8e8 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,19 +5,19 @@ using DiagonalArrays: diagonal using GPUArraysCore: @allowscalar using JLArrays: JLArray using KroneckerArrays: - KroneckerArrays, - KroneckerArray, - KroneckerStyle, - CartesianProductUnitRange, - CartesianProductVector, - ⊗, - ×, - arg1, - arg2, - cartesianproduct, - cartesianrange, - kron_nd, - unproduct + KroneckerArrays, + KroneckerArray, + KroneckerStyle, + CartesianProductUnitRange, + CartesianProductVector, + ⊗, + ×, + arg1, + arg2, + cartesianproduct, + cartesianrange, + kron_nd, + unproduct using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr using StableRNGs: StableRNG using Test: @test, @test_broken, @test_throws, @testset @@ -25,225 +25,225 @@ using TestExtras: @constinferred elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "KroneckerArrays (eltype=$elt)" for elt in elts - p = [1, 2] × [3, 4, 5] - @test length(p) == 6 - @test collect(p) == [1 × 3, 1 × 4, 1 × 5, 2 × 3, 2 × 4, 2 × 5] - - r = @constinferred cartesianrange(2, 3) - @test r === - @constinferred(cartesianrange(2 × 3)) === - @constinferred(cartesianrange(Base.OneTo(2), Base.OneTo(3))) === - @constinferred(cartesianrange(Base.OneTo(2) × Base.OneTo(3))) - @test @constinferred(cartesianproduct(r)) === Base.OneTo(2) × Base.OneTo(3) - @test unproduct(r) === Base.OneTo(6) - @test length(r) == 6 - @test first(r) == 1 - @test last(r) == 6 - @test r[1 × 1] == 1 - @test r[1 × 2] == 2 - @test r[1 × 3] == 3 - @test r[2 × 1] == 4 - @test r[2 × 2] == 5 - @test r[2 × 3] == 6 - - @test sprint(show, "text/plain", cartesianrange(2 × 3)) == - "Base.OneTo(2) × Base.OneTo(3)\nBase.OneTo(6)" - @test sprint(show, cartesianrange(2 × 3)) == "Base.OneTo(6)" - - # CartesianProductUnitRange axes - r = cartesianrange((2:3) × (3:4), 2:5) - @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) - - # CartesianProductUnitRange getindex - r1 = cartesianrange((2:4) × (3:5), 2:10) - r2 = cartesianrange((2:3) × (2:3), 2:5) - @test r1[r2] ≡ cartesianrange((3:4) × (4:5), 3:6) - - @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) - - # CartesianProductVector axes - r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9]) - @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) - - r = @constinferred(cartesianrange(2 × 3, 2:7)) - @test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7) - @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) - @test unproduct(r) === 2:7 - @test length(r) == 6 - @test first(r) == 2 - @test last(r) == 7 - @test r[1 × 1] == 2 - @test r[1 × 2] == 3 - @test r[1 × 3] == 4 - @test r[2 × 1] == 5 - @test r[2 × 2] == 6 - @test r[2 × 3] == 7 - - # Test high-dimensional materialization. - a = randn(elt, 2, 2, 2) ⊗ randn(elt, 2, 2, 2) - x = Array(a) - y = similar(x) - for I in eachindex(a) - y[I] = @allowscalar x[I] - end - @test x == y - - rng = StableRNG(123) - a = @constinferred(randn(rng, elt, 2, 2) ⊗ randn(rng, elt, 3, 3)) - b = @constinferred(randn(rng, elt, 2, 2) ⊗ randn(rng, elt, 3, 3)) - c = @constinferred(a.arg1 ⊗ b.arg2) - @test a isa KroneckerArray{elt,2,typeof(a.arg1),typeof(a.arg2)} - @test similar(typeof(a), (2, 3)) isa Matrix{elt} - @test size(similar(typeof(a), (2, 3))) == (2, 3) - @test isreal(a) == (elt <: Real) - @test a[1 × 1, 1 × 1] == a.arg1[1, 1] * a.arg2[1, 1] - @test a[1 × 3, 2 × 1] == a.arg1[1, 2] * a.arg2[3, 1] - @test a[1 × (2:3), 2 × 1] == a.arg1[1, 2] * a.arg2[2:3, 1] - @test a[1 × :, (:) × 1] == a.arg1[1, :] ⊗ a.arg2[:, 1] - @test a[(1:2) × (2:3), (1:2) × (2:3)] == a.arg1[1:2, 1:2] ⊗ a.arg2[2:3, 2:3] - v = randn(elt, 2) ⊗ randn(elt, 3) - @test v[1 × 1] == v.arg1[1] * v.arg2[1] - @test v[1 × 3] == v.arg1[1] * v.arg2[3] - @test v[(1:2) × 3] == v.arg1[1:2] * v.arg2[3] - @test v[(1:2) × (2:3)] == v.arg1[1:2] ⊗ v.arg2[2:3] - @test eltype(a) === elt - @test collect(a) == kron(collect(a.arg1), collect(a.arg2)) - @test size(a) == (6, 6) - @test collect(a * b) ≈ collect(a) * collect(b) - @test collect(-a) == -collect(a) - @test collect(3 * a) ≈ 3 * collect(a) - @test collect(a * 3) ≈ collect(a) * 3 - @test collect(a / 3) ≈ collect(a) / 3 - @test a + a == 2a - @test iszero(a - a) - @test collect(a + c) ≈ collect(a) + collect(c) - @test collect(b + c) ≈ collect(b) + collect(c) - for f in (transpose, adjoint, inv, pinv) - @test collect(f(a)) ≈ f(collect(a)) - end - @test tr(a) ≈ tr(collect(a)) - @test norm(a) ≈ norm(collect(a)) - - # Views - a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3)) - b = @constinferred(view(a, (1:2) × (2:3), (1:2) × (2:3))) - @test arg1(b) === view(arg1(a), 1:2, 1:2) - @test arg1(b) == arg1(a)[1:2, 1:2] - @test arg2(b) === view(arg2(a), 2:3, 2:3) - @test arg2(b) == arg2(a)[2:3, 2:3] - - # Broadcasting - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - style = KroneckerStyle(BroadcastStyle(typeof(a.arg1)), BroadcastStyle(typeof(a.arg2))) - @test BroadcastStyle(typeof(a)) === style - @test_throws "not supported" sin.(a) - a′ = similar(a) - @test_throws "not supported" a′ .= sin.(a) - a′ = similar(a) - a′ .= 2 .* a - @test collect(a′) ≈ 2 * collect(a) - bc = broadcasted(+, a, a) - @test bc.style === style - @test similar(bc, elt) isa KroneckerArray{elt,2,typeof(a.arg1),typeof(a.arg2)} - @test collect(copy(bc)) ≈ 2 * collect(a) - bc = broadcasted(*, 2, a) - @test bc.style === style - @test collect(copy(bc)) ≈ 2 * collect(a) - - # Mapping - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - @test_throws "not supported" map(sin, a) - @test collect(map(Base.Fix1(*, 2), a)) ≈ 2 * collect(a) - a′ = similar(a) - @test_throws "not supported" map!(sin, a′, a) - a′ = similar(a) - map!(identity, a′, a) - @test collect(a′) ≈ collect(a) - a′ = similar(a) - map!(+, a′, a, a) - @test collect(a′) ≈ 2 * collect(a) - a′ = similar(a) - map!(-, a′, a, a) - @test norm(collect(a′)) ≈ 0 - a′ = similar(a) - map!(Base.Fix1(*, 2), a′, a) - @test collect(a′) ≈ 2 * collect(a) - a′ = similar(a) - map!(Base.Fix2(*, 2), a′, a) - @test collect(a′) ≈ 2 * collect(a) - a′ = similar(a) - map!(Base.Fix2(/, 2), a′, a) - @test collect(a′) ≈ collect(a) / 2 - a′ = similar(a) - map!(conj, a′, a) - @test collect(a′) ≈ conj(collect(a)) - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - if elt <: Real - @test real(a) == a - else - @test_throws ErrorException real(a) - end - if elt <: Real - @test iszero(imag(a)) - else - @test_throws ErrorException imag(a) - end - - # permutedims - a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) - @test permutedims(a, (2, 1, 3)) == - permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) - - # permutedims! - a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) - b = similar(a) - permutedims!(b, a, (2, 1, 3)) - @test b == permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) - - # Adapt - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - a′ = adapt(JLArray, a) - @test a′ isa KroneckerArray{elt,2,JLArray{elt,2},JLArray{elt,2}} - @test a′.arg1 isa JLArray{elt,2} - @test a′.arg2 isa JLArray{elt,2} - @test Array(a′.arg1) == a.arg1 - @test Array(a′.arg2) == a.arg2 - - a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) - @test collect(a) ≈ kron_nd(a.arg1, a.arg2) - @test a[1 × 1, 1 × 1, 1 × 1] == a.arg1[1, 1, 1] * a.arg2[1, 1, 1] - @test a[1 × 3, 2 × 1, 2 × 2] == a.arg1[1, 2, 2] * a.arg2[3, 1, 2] - @test collect(a + a) ≈ 2 * collect(a) - @test collect(2a) ≈ 2 * collect(a) - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c = a.arg1 ⊗ b.arg2 - U, S, V = svd(a) - @test collect(U * diagonal(S) * V') ≈ collect(a) - @test svdvals(a) ≈ S - @test sort(collect(S); rev=true) ≈ svdvals(collect(a)) - @test collect(U'U) ≈ I - @test collect(V * V') ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - D, V = eigen(a) - @test collect(a * V) ≈ collect(V * diagonal(D)) - @test eigvals(a) ≈ D - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - Q, R = qr(a) - @test collect(Q * R) ≈ collect(a) - @test collect(Q'Q) ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - @test det(a) ≈ det(collect(a)) - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - for f in KroneckerArrays.MATRIX_FUNCTIONS - @eval begin - @test_throws ArgumentError $f($a) + p = [1, 2] × [3, 4, 5] + @test length(p) == 6 + @test collect(p) == [1 × 3, 1 × 4, 1 × 5, 2 × 3, 2 × 4, 2 × 5] + + r = @constinferred cartesianrange(2, 3) + @test r === + @constinferred(cartesianrange(2 × 3)) === + @constinferred(cartesianrange(Base.OneTo(2), Base.OneTo(3))) === + @constinferred(cartesianrange(Base.OneTo(2) × Base.OneTo(3))) + @test @constinferred(cartesianproduct(r)) === Base.OneTo(2) × Base.OneTo(3) + @test unproduct(r) === Base.OneTo(6) + @test length(r) == 6 + @test first(r) == 1 + @test last(r) == 6 + @test r[1 × 1] == 1 + @test r[1 × 2] == 2 + @test r[1 × 3] == 3 + @test r[2 × 1] == 4 + @test r[2 × 2] == 5 + @test r[2 × 3] == 6 + + @test sprint(show, "text/plain", cartesianrange(2 × 3)) == + "Base.OneTo(2) × Base.OneTo(3)\nBase.OneTo(6)" + @test sprint(show, cartesianrange(2 × 3)) == "Base.OneTo(6)" + + # CartesianProductUnitRange axes + r = cartesianrange((2:3) × (3:4), 2:5) + @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + + # CartesianProductUnitRange getindex + r1 = cartesianrange((2:4) × (3:5), 2:10) + r2 = cartesianrange((2:3) × (2:3), 2:5) + @test r1[r2] ≡ cartesianrange((3:4) × (4:5), 3:6) + + @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + + # CartesianProductVector axes + r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9]) + @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + + r = @constinferred(cartesianrange(2 × 3, 2:7)) + @test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7) + @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) + @test unproduct(r) === 2:7 + @test length(r) == 6 + @test first(r) == 2 + @test last(r) == 7 + @test r[1 × 1] == 2 + @test r[1 × 2] == 3 + @test r[1 × 3] == 4 + @test r[2 × 1] == 5 + @test r[2 × 2] == 6 + @test r[2 × 3] == 7 + + # Test high-dimensional materialization. + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 2, 2, 2) + x = Array(a) + y = similar(x) + for I in eachindex(a) + y[I] = @allowscalar x[I] + end + @test x == y + + rng = StableRNG(123) + a = @constinferred(randn(rng, elt, 2, 2) ⊗ randn(rng, elt, 3, 3)) + b = @constinferred(randn(rng, elt, 2, 2) ⊗ randn(rng, elt, 3, 3)) + c = @constinferred(a.arg1 ⊗ b.arg2) + @test a isa KroneckerArray{elt, 2, typeof(a.arg1), typeof(a.arg2)} + @test similar(typeof(a), (2, 3)) isa Matrix{elt} + @test size(similar(typeof(a), (2, 3))) == (2, 3) + @test isreal(a) == (elt <: Real) + @test a[1 × 1, 1 × 1] == a.arg1[1, 1] * a.arg2[1, 1] + @test a[1 × 3, 2 × 1] == a.arg1[1, 2] * a.arg2[3, 1] + @test a[1 × (2:3), 2 × 1] == a.arg1[1, 2] * a.arg2[2:3, 1] + @test a[1 × :, (:) × 1] == a.arg1[1, :] ⊗ a.arg2[:, 1] + @test a[(1:2) × (2:3), (1:2) × (2:3)] == a.arg1[1:2, 1:2] ⊗ a.arg2[2:3, 2:3] + v = randn(elt, 2) ⊗ randn(elt, 3) + @test v[1 × 1] == v.arg1[1] * v.arg2[1] + @test v[1 × 3] == v.arg1[1] * v.arg2[3] + @test v[(1:2) × 3] == v.arg1[1:2] * v.arg2[3] + @test v[(1:2) × (2:3)] == v.arg1[1:2] ⊗ v.arg2[2:3] + @test eltype(a) === elt + @test collect(a) == kron(collect(a.arg1), collect(a.arg2)) + @test size(a) == (6, 6) + @test collect(a * b) ≈ collect(a) * collect(b) + @test collect(-a) == -collect(a) + @test collect(3 * a) ≈ 3 * collect(a) + @test collect(a * 3) ≈ collect(a) * 3 + @test collect(a / 3) ≈ collect(a) / 3 + @test a + a == 2a + @test iszero(a - a) + @test collect(a + c) ≈ collect(a) + collect(c) + @test collect(b + c) ≈ collect(b) + collect(c) + for f in (transpose, adjoint, inv, pinv) + @test collect(f(a)) ≈ f(collect(a)) + end + @test tr(a) ≈ tr(collect(a)) + @test norm(a) ≈ norm(collect(a)) + + # Views + a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3)) + b = @constinferred(view(a, (1:2) × (2:3), (1:2) × (2:3))) + @test arg1(b) === view(arg1(a), 1:2, 1:2) + @test arg1(b) == arg1(a)[1:2, 1:2] + @test arg2(b) === view(arg2(a), 2:3, 2:3) + @test arg2(b) == arg2(a)[2:3, 2:3] + + # Broadcasting + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + style = KroneckerStyle(BroadcastStyle(typeof(a.arg1)), BroadcastStyle(typeof(a.arg2))) + @test BroadcastStyle(typeof(a)) === style + @test_throws "not supported" sin.(a) + a′ = similar(a) + @test_throws "not supported" a′ .= sin.(a) + a′ = similar(a) + a′ .= 2 .* a + @test collect(a′) ≈ 2 * collect(a) + bc = broadcasted(+, a, a) + @test bc.style === style + @test similar(bc, elt) isa KroneckerArray{elt, 2, typeof(a.arg1), typeof(a.arg2)} + @test collect(copy(bc)) ≈ 2 * collect(a) + bc = broadcasted(*, 2, a) + @test bc.style === style + @test collect(copy(bc)) ≈ 2 * collect(a) + + # Mapping + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + @test_throws "not supported" map(sin, a) + @test collect(map(Base.Fix1(*, 2), a)) ≈ 2 * collect(a) + a′ = similar(a) + @test_throws "not supported" map!(sin, a′, a) + a′ = similar(a) + map!(identity, a′, a) + @test collect(a′) ≈ collect(a) + a′ = similar(a) + map!(+, a′, a, a) + @test collect(a′) ≈ 2 * collect(a) + a′ = similar(a) + map!(-, a′, a, a) + @test norm(collect(a′)) ≈ 0 + a′ = similar(a) + map!(Base.Fix1(*, 2), a′, a) + @test collect(a′) ≈ 2 * collect(a) + a′ = similar(a) + map!(Base.Fix2(*, 2), a′, a) + @test collect(a′) ≈ 2 * collect(a) + a′ = similar(a) + map!(Base.Fix2(/, 2), a′, a) + @test collect(a′) ≈ collect(a) / 2 + a′ = similar(a) + map!(conj, a′, a) + @test collect(a′) ≈ conj(collect(a)) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + if elt <: Real + @test real(a) == a + else + @test_throws ErrorException real(a) + end + if elt <: Real + @test iszero(imag(a)) + else + @test_throws ErrorException imag(a) + end + + # permutedims + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) + @test permutedims(a, (2, 1, 3)) == + permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) + + # permutedims! + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) + b = similar(a) + permutedims!(b, a, (2, 1, 3)) + @test b == permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) + + # Adapt + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + a′ = adapt(JLArray, a) + @test a′ isa KroneckerArray{elt, 2, JLArray{elt, 2}, JLArray{elt, 2}} + @test a′.arg1 isa JLArray{elt, 2} + @test a′.arg2 isa JLArray{elt, 2} + @test Array(a′.arg1) == a.arg1 + @test Array(a′.arg2) == a.arg2 + + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) + @test collect(a) ≈ kron_nd(a.arg1, a.arg2) + @test a[1 × 1, 1 × 1, 1 × 1] == a.arg1[1, 1, 1] * a.arg2[1, 1, 1] + @test a[1 × 3, 2 × 1, 2 × 2] == a.arg1[1, 2, 2] * a.arg2[3, 1, 2] + @test collect(a + a) ≈ 2 * collect(a) + @test collect(2a) ≈ 2 * collect(a) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c = a.arg1 ⊗ b.arg2 + U, S, V = svd(a) + @test collect(U * diagonal(S) * V') ≈ collect(a) + @test svdvals(a) ≈ S + @test sort(collect(S); rev = true) ≈ svdvals(collect(a)) + @test collect(U'U) ≈ I + @test collect(V * V') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + D, V = eigen(a) + @test collect(a * V) ≈ collect(V * diagonal(D)) + @test eigvals(a) ≈ D + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + Q, R = qr(a) + @test collect(Q * R) ≈ collect(a) + @test collect(Q'Q) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + @test det(a) ≈ det(collect(a)) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + for f in KroneckerArrays.MATRIX_FUNCTIONS + @eval begin + @test_throws ArgumentError $f($a) + end end - end end diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 69be9f4..3cf1d8d 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -1,7 +1,7 @@ using Adapt: adapt using BlockArrays: Block, BlockRange, blockedrange, blockisequal, mortar using BlockSparseArrays: - BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype, eachblockaxis + BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype, eachblockaxis using DiagonalArrays: DeltaMatrix, δ using JLArrays: JLArray using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2, cartesianrange @@ -14,412 +14,412 @@ using TestExtras: @constinferred, @constinferred_broken elts = (Float32, Float64, ComplexF32) arrayts = (Array, JLArray) @testset "BlockSparseArraysExt, KroneckerArray blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in - arrayts, - elt in elts - - # BlockUnitRange with CartesianProduct blocks - r = blockrange([2 × 3, 3 × 4]) - @test r[Block(1)] ≡ cartesianrange(2 × 3, 1:6) - @test r[Block(2)] ≡ cartesianrange(3 × 4, 7:18) - @test eachblockaxis(r)[1] ≡ cartesianrange(2, 3) - @test eachblockaxis(r)[2] ≡ cartesianrange(3, 4) - @test blockisequal(arg1(r), blockedrange([2, 3])) - @test blockisequal(arg2(r), blockedrange([3, 4])) - - r = blockrange([2 × 3, 3 × 4]) - r′ = r[Block.([2, 1])] - @test r′[Block(1)] ≡ cartesianrange(3 × 4, 7:18) - @test r′[Block(2)] ≡ cartesianrange(2 × 3, 1:6) - @test eachblockaxis(r′)[1] ≡ cartesianrange(3, 4) - @test eachblockaxis(r′)[2] ≡ cartesianrange(2, 3) - - dev = adapt(arrayt) - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - @test sprint(show, a) isa String - @test sprint(show, MIME("text/plain"), a) isa String - @test blocktype(a) === valtype(d) - @test a isa BlockSparseMatrix{elt,valtype(d)} - @test a[Block(1, 1)] == dev(d[Block(1, 1)]) - @test a[Block(1, 1)] isa valtype(d) - @test a[Block(2, 2)] == dev(d[Block(2, 2)]) - @test a[Block(2, 2)] isa valtype(d) - @test iszero(a[Block(2, 1)]) - @test a[Block(2, 1)] == dev(zeros(elt, 3, 2) ⊗ zeros(elt, 3, 2)) - @test a[Block(2, 1)] isa valtype(d) - @test iszero(a[Block(1, 2)]) - @test a[Block(1, 2)] == dev(zeros(elt, 2, 3) ⊗ zeros(elt, 2, 3)) - @test a[Block(1, 2)] isa valtype(d) - - # Slicing - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == - a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] - @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] - @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == - a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] - - # Blockwise slicing, shows up in truncated block sparse matrix factorizations. - I1 = Block(1)[Base.Slice(Base.OneTo(2)) × [1]] - I2 = Block(2)[Base.Slice(Base.OneTo(3)) × [1, 3]] - I = [I1, I2] - b = a[I, I] - @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] - @test iszero(b[Block(2, 1)]) - @test iszero(b[Block(1, 2)]) - @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] - - # Slicing - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - i1 = Block(1)[(1:2) × (1:2)] - i2 = Block(2)[(2:3) × (2:3)] - I = mortar([i1, i2]) - b = @view a[I, I] - b′ = copy(b) - @test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] - @test_broken b[Block(1, 2)] - - # Slicing - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - i1 = Block(1)[(1:2) × (1:2)] - i2 = Block(2)[(2:3) × (2:3)] - I = [i1, i2] - b = @view a[I, I] - b′ = copy(b) - @test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] - @test_broken b[Block(1, 2)] - - # Matrix multiplication - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - b = a * a - @test typeof(b) === typeof(a) - @test Array(b) ≈ Array(a) * Array(a) - - # Addition (mapping, broadcasting) - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - b = a + a - @test typeof(b) === typeof(a) - @test Array(b) ≈ Array(a) + Array(a) - - # Scaling (mapping, broadcasting) - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - b = 3a - @test typeof(b) === typeof(a) - @test Array(b) ≈ 3Array(a) - - # Dividing (mapping, broadcasting) - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - b = a / 3 - @test typeof(b) === typeof(a) - @test Array(b) ≈ Array(a) / 3 - - # Norm - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - @test norm(a) ≈ norm(Array(a)) - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - if arrayt === Array - @test Array(inv(a)) ≈ inv(Array(a)) - else - # Broken on GPU. - @test_broken inv(a) - end - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - u, s, v = svd_compact(a) - @test Array(u * s * v) ≈ Array(a) - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - b = a[Block.(1:2), Block(2)] - @test b[Block(1)] == a[Block(1, 2)] - @test b[Block(2)] == a[Block(2, 2)] - - # Broken operations - @test_broken exp(a) + arrayts, + elt in elts + + # BlockUnitRange with CartesianProduct blocks + r = blockrange([2 × 3, 3 × 4]) + @test r[Block(1)] ≡ cartesianrange(2 × 3, 1:6) + @test r[Block(2)] ≡ cartesianrange(3 × 4, 7:18) + @test eachblockaxis(r)[1] ≡ cartesianrange(2, 3) + @test eachblockaxis(r)[2] ≡ cartesianrange(3, 4) + @test blockisequal(arg1(r), blockedrange([2, 3])) + @test blockisequal(arg2(r), blockedrange([3, 4])) + + r = blockrange([2 × 3, 3 × 4]) + r′ = r[Block.([2, 1])] + @test r′[Block(1)] ≡ cartesianrange(3 × 4, 7:18) + @test r′[Block(2)] ≡ cartesianrange(2 × 3, 1:6) + @test eachblockaxis(r′)[1] ≡ cartesianrange(3, 4) + @test eachblockaxis(r′)[2] ≡ cartesianrange(2, 3) + + dev = adapt(arrayt) + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + @test sprint(show, a) isa String + @test sprint(show, MIME("text/plain"), a) isa String + @test blocktype(a) === valtype(d) + @test a isa BlockSparseMatrix{elt, valtype(d)} + @test a[Block(1, 1)] == dev(d[Block(1, 1)]) + @test a[Block(1, 1)] isa valtype(d) + @test a[Block(2, 2)] == dev(d[Block(2, 2)]) + @test a[Block(2, 2)] isa valtype(d) + @test iszero(a[Block(2, 1)]) + @test a[Block(2, 1)] == dev(zeros(elt, 3, 2) ⊗ zeros(elt, 3, 2)) + @test a[Block(2, 1)] isa valtype(d) + @test iszero(a[Block(1, 2)]) + @test a[Block(1, 2)] == dev(zeros(elt, 2, 3) ⊗ zeros(elt, 2, 3)) + @test a[Block(1, 2)] isa valtype(d) + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == + a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] + @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] + @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == + a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + + # Blockwise slicing, shows up in truncated block sparse matrix factorizations. + I1 = Block(1)[Base.Slice(Base.OneTo(2)) × [1]] + I2 = Block(2)[Base.Slice(Base.OneTo(3)) × [1, 3]] + I = [I1, I2] + b = a[I, I] + @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] + @test iszero(b[Block(2, 1)]) + @test iszero(b[Block(1, 2)]) + @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + i1 = Block(1)[(1:2) × (1:2)] + i2 = Block(2)[(2:3) × (2:3)] + I = mortar([i1, i2]) + b = @view a[I, I] + b′ = copy(b) + @test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] + @test_broken b[Block(1, 2)] + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + i1 = Block(1)[(1:2) × (1:2)] + i2 = Block(2)[(2:3) × (2:3)] + I = [i1, i2] + b = @view a[I, I] + b′ = copy(b) + @test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] + @test_broken b[Block(1, 2)] + + # Matrix multiplication + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + b = a * a + @test typeof(b) === typeof(a) + @test Array(b) ≈ Array(a) * Array(a) + + # Addition (mapping, broadcasting) + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + b = a + a + @test typeof(b) === typeof(a) + @test Array(b) ≈ Array(a) + Array(a) + + # Scaling (mapping, broadcasting) + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + b = 3a + @test typeof(b) === typeof(a) + @test Array(b) ≈ 3Array(a) + + # Dividing (mapping, broadcasting) + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + b = a / 3 + @test typeof(b) === typeof(a) + @test Array(b) ≈ Array(a) / 3 + + # Norm + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + @test norm(a) ≈ norm(Array(a)) + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + if arrayt === Array + @test Array(inv(a)) ≈ inv(Array(a)) + else + # Broken on GPU. + @test_broken inv(a) + end + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + u, s, v = svd_compact(a) + @test Array(u * s * v) ≈ Array(a) + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + b = a[Block.(1:2), Block(2)] + @test b[Block(1)] == a[Block(1, 2)] + @test b[Block(2)] == a[Block(2, 2)] + + # Broken operations + @test_broken exp(a) end @testset "BlockSparseArraysExt, DeltaKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in - arrayts, - elt in elts - - dev = adapt(arrayt) - r = @constinferred blockrange([2 × 2, 2 × 3]) - d = Dict( - Block(1, 1) => δ(elt, (2, 2)) ⊗ dev(randn(elt, 2, 2)), - Block(2, 2) => δ(elt, (2, 2)) ⊗ dev(randn(elt, 3, 3)), - ) - a = @constinferred dev(blocksparse(d, (r, r))) - @test sprint(show, a) == sprint(show, Array(a)) - @test sprint(show, MIME("text/plain"), a) isa String - @test @constinferred(blocktype(a)) === valtype(d) - @test a isa BlockSparseMatrix{elt,valtype(d)} - @test @constinferred(a[Block(1, 1)]) == dev(d[Block(1, 1)]) - @test @constinferred(a[Block(1, 1)]) isa valtype(d) - @test @constinferred(a[Block(2, 2)]) == dev(d[Block(2, 2)]) - @test @constinferred(a[Block(2, 2)]) isa valtype(d) - @test @constinferred(iszero(a[Block(2, 1)])) - @test a[Block(2, 1)] == dev(δ(2, 2) ⊗ zeros(elt, 3, 2)) - @test a[Block(2, 1)] isa valtype(d) - @test @constinferred(iszero(a[Block(1, 2)])) - @test a[Block(1, 2)] == dev(δ(2, 2) ⊗ zeros(elt, 2, 3)) - @test a[Block(1, 2)] isa valtype(d) - - # Slicing - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == - a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] - @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] - @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == - a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] - - # Blockwise slicing, shows up in truncated block sparse matrix factorizations. - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - I1 = Block(1)[Base.Slice(Base.OneTo(2)) × [1]] - I2 = Block(2)[Base.Slice(Base.OneTo(3)) × [1, 3]] - I = [I1, I2] - b = a[I, I] - @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] - @test arg1(b[Block(1, 1)]) isa DeltaMatrix - @test iszero(b[Block(2, 1)]) - @test arg1(b[Block(2, 1)]) isa DeltaMatrix - @test iszero(b[Block(1, 2)]) - @test arg1(b[Block(1, 2)]) isa DeltaMatrix - @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] - @test arg1(b[Block(2, 2)]) isa DeltaMatrix - - # Slicing - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - i1 = Block(1)[(1:2) × (1:2)] - i2 = Block(2)[(2:3) × (2:3)] - I = mortar([i1, i2]) - b = @view a[I, I] - @test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] - @test_broken copy(b) - @test_broken b[Block(1, 2)] - - # Slicing - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - i1 = Block(1)[(1:2) × (1:2)] - i2 = Block(2)[(2:3) × (2:3)] - I = [i1, i2] - b = @view a[I, I] - @test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] - @test_broken copy(b) - @test_broken b[Block(1, 2)] - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - b = @constinferred a * a - @test typeof(b) === typeof(a) - @test Array(b) ≈ Array(a) * Array(a) - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - # Type inference is broken for this operation. - # b = @constinferred a + a - b = a + a - @test typeof(b) === typeof(a) - @test Array(b) ≈ Array(a) + Array(a) - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - # Type inference is broken for this operation. - # b = @constinferred 3a - b = 3a - @test typeof(b) === typeof(a) - @test Array(b) ≈ 3Array(a) - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - # Type inference is broken for this operation. - # b = @constinferred a / 3 - b = a / 3 - @test typeof(b) === typeof(a) - @test Array(b) ≈ Array(a) / 3 - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - if VERSION ≥ v"1.11-" - @test @constinferred(norm(a)) ≈ norm(Array(a)) - else - # Type inference fails in Julia 1.10. - @test @constinferred_broken(norm(a)) ≈ norm(Array(a)) - end - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - if arrayt === Array - b = @constinferred exp(a) - @test Array(b) ≈ exp(Array(a)) - else - @test_broken exp(a) - end - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - u, s, v = svd_compact(a) - @test u * s * v ≈ a - @test blocktype(u) >: blocktype(u) - @test eltype(u) === eltype(a) - @test blocktype(v) >: blocktype(a) - @test eltype(v) === eltype(a) - @test eltype(s) === real(eltype(a)) - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - if arrayt === Array - @test Array(inv(a)) ≈ inv(Array(a)) - else - # Broken on GPU. - @test_broken inv(a) - end - - r = blockrange([2 × 2, 3 × 3]) - d = Dict( - Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), - Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), - ) - a = dev(blocksparse(d, (r, r))) - b = a[Block.(1:2), Block(2)] - @test b[Block(1)] == a[Block(1, 2)] - @test b[Block(2)] == a[Block(2, 2)] - - # svd_trunc - dev = adapt(arrayt) - r = @constinferred blockrange([2 × 2, 3 × 3]) - rng = StableRNG(1234) - d = Dict( - Block(1, 1) => δ(elt, (2, 2)) ⊗ randn(rng, elt, 2, 2), - Block(2, 2) => δ(elt, (3, 3)) ⊗ randn(rng, elt, 3, 3), - ) - a = @constinferred dev(blocksparse(d, (r, r))) - if arrayt === Array - u, s, v = svd_trunc(a; trunc=(; maxrank=6)) - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - else - @test_broken svd_trunc(a; trunc=(; maxrank=6)) - end - - @testset "Block deficient" begin - da = Dict(Block(1, 1) => δ(elt, (2, 2)) ⊗ dev(randn(elt, 2, 2))) - a = @constinferred dev(blocksparse(da, (r, r))) - - db = Dict(Block(2, 2) => δ(elt, (3, 3)) ⊗ dev(randn(elt, 3, 3))) - b = @constinferred dev(blocksparse(db, (r, r))) - - @test Array(a + b) ≈ Array(a) + Array(b) - @test Array(2a) ≈ 2Array(a) - end + arrayts, + elt in elts + + dev = adapt(arrayt) + r = @constinferred blockrange([2 × 2, 2 × 3]) + d = Dict( + Block(1, 1) => δ(elt, (2, 2)) ⊗ dev(randn(elt, 2, 2)), + Block(2, 2) => δ(elt, (2, 2)) ⊗ dev(randn(elt, 3, 3)), + ) + a = @constinferred dev(blocksparse(d, (r, r))) + @test sprint(show, a) == sprint(show, Array(a)) + @test sprint(show, MIME("text/plain"), a) isa String + @test @constinferred(blocktype(a)) === valtype(d) + @test a isa BlockSparseMatrix{elt, valtype(d)} + @test @constinferred(a[Block(1, 1)]) == dev(d[Block(1, 1)]) + @test @constinferred(a[Block(1, 1)]) isa valtype(d) + @test @constinferred(a[Block(2, 2)]) == dev(d[Block(2, 2)]) + @test @constinferred(a[Block(2, 2)]) isa valtype(d) + @test @constinferred(iszero(a[Block(2, 1)])) + @test a[Block(2, 1)] == dev(δ(2, 2) ⊗ zeros(elt, 3, 2)) + @test a[Block(2, 1)] isa valtype(d) + @test @constinferred(iszero(a[Block(1, 2)])) + @test a[Block(1, 2)] == dev(δ(2, 2) ⊗ zeros(elt, 2, 3)) + @test a[Block(1, 2)] isa valtype(d) + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == + a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] + @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] + @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == + a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + + # Blockwise slicing, shows up in truncated block sparse matrix factorizations. + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + I1 = Block(1)[Base.Slice(Base.OneTo(2)) × [1]] + I2 = Block(2)[Base.Slice(Base.OneTo(3)) × [1, 3]] + I = [I1, I2] + b = a[I, I] + @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] + @test arg1(b[Block(1, 1)]) isa DeltaMatrix + @test iszero(b[Block(2, 1)]) + @test arg1(b[Block(2, 1)]) isa DeltaMatrix + @test iszero(b[Block(1, 2)]) + @test arg1(b[Block(1, 2)]) isa DeltaMatrix + @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] + @test arg1(b[Block(2, 2)]) isa DeltaMatrix + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + i1 = Block(1)[(1:2) × (1:2)] + i2 = Block(2)[(2:3) × (2:3)] + I = mortar([i1, i2]) + b = @view a[I, I] + @test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] + @test_broken copy(b) + @test_broken b[Block(1, 2)] + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + i1 = Block(1)[(1:2) × (1:2)] + i2 = Block(2)[(2:3) × (2:3)] + I = [i1, i2] + b = @view a[I, I] + @test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] + @test_broken copy(b) + @test_broken b[Block(1, 2)] + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + b = @constinferred a * a + @test typeof(b) === typeof(a) + @test Array(b) ≈ Array(a) * Array(a) + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + # Type inference is broken for this operation. + # b = @constinferred a + a + b = a + a + @test typeof(b) === typeof(a) + @test Array(b) ≈ Array(a) + Array(a) + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + # Type inference is broken for this operation. + # b = @constinferred 3a + b = 3a + @test typeof(b) === typeof(a) + @test Array(b) ≈ 3Array(a) + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + # Type inference is broken for this operation. + # b = @constinferred a / 3 + b = a / 3 + @test typeof(b) === typeof(a) + @test Array(b) ≈ Array(a) / 3 + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + if VERSION ≥ v"1.11-" + @test @constinferred(norm(a)) ≈ norm(Array(a)) + else + # Type inference fails in Julia 1.10. + @test @constinferred_broken(norm(a)) ≈ norm(Array(a)) + end + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + if arrayt === Array + b = @constinferred exp(a) + @test Array(b) ≈ exp(Array(a)) + else + @test_broken exp(a) + end + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + u, s, v = svd_compact(a) + @test u * s * v ≈ a + @test blocktype(u) >: blocktype(u) + @test eltype(u) === eltype(a) + @test blocktype(v) >: blocktype(a) + @test eltype(v) === eltype(a) + @test eltype(s) === real(eltype(a)) + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + if arrayt === Array + @test Array(inv(a)) ≈ inv(Array(a)) + else + # Broken on GPU. + @test_broken inv(a) + end + + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(δ(elt, (2, 2)) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(δ(elt, (3, 3)) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, (r, r))) + b = a[Block.(1:2), Block(2)] + @test b[Block(1)] == a[Block(1, 2)] + @test b[Block(2)] == a[Block(2, 2)] + + # svd_trunc + dev = adapt(arrayt) + r = @constinferred blockrange([2 × 2, 3 × 3]) + rng = StableRNG(1234) + d = Dict( + Block(1, 1) => δ(elt, (2, 2)) ⊗ randn(rng, elt, 2, 2), + Block(2, 2) => δ(elt, (3, 3)) ⊗ randn(rng, elt, 3, 3), + ) + a = @constinferred dev(blocksparse(d, (r, r))) + if arrayt === Array + u, s, v = svd_trunc(a; trunc = (; maxrank = 6)) + u′, s′, v′ = svd_trunc(Matrix(a); trunc = (; maxrank = 5)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + else + @test_broken svd_trunc(a; trunc = (; maxrank = 6)) + end + + @testset "Block deficient" begin + da = Dict(Block(1, 1) => δ(elt, (2, 2)) ⊗ dev(randn(elt, 2, 2))) + a = @constinferred dev(blocksparse(da, (r, r))) + + db = Dict(Block(2, 2) => δ(elt, (3, 3)) ⊗ dev(randn(elt, 3, 3))) + b = @constinferred dev(blocksparse(db, (r, r))) + + @test Array(a + b) ≈ Array(a) + Array(b) + @test Array(2a) ≈ 2Array(a) + end end diff --git a/test/test_delta.jl b/test/test_delta.jl index ac6f990..256175a 100644 --- a/test/test_delta.jl +++ b/test/test_delta.jl @@ -10,405 +10,405 @@ using Test: @test, @test_broken, @test_throws, @testset using TestExtras: @constinferred @testset "FillArrays.Eye, DiagonalArrays.Delta" begin - MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS - if VERSION < v"1.11-" - # `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11. - MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt]) - end - - a = Eye(2) ⊗ randn(3, 3) - @test size(a) == (6, 6) - @test a + a == Eye(2) ⊗ (2 * arg2(a)) - @test 2a == Eye(2) ⊗ (2 * arg2(a)) - @test a * a == Eye(2) ⊗ (arg2(a) * arg2(a)) - @test_broken arg1(a[(:) × (:), (:) × (:)]) ≡ Eye(2) - @test_broken arg1(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) - @test_broken arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) - @test_broken arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) - @test_broken arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) - @test_broken arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) - @test_broken arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ - Eye(2) - @test_broken arg1( - view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)) - ) ≡ Eye(2) - @test arg1(adapt(JLArray, a)) ≡ Eye(2) - @test arg2(adapt(JLArray, a)) == jl(arg2(a)) - @test arg2(adapt(JLArray, a)) isa JLArray - @test_broken arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye(3) - @test_broken arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ - Eye(3) - @test_broken arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ - Eye{Float32}(3) - @test arg1(copy(a)) ≡ Eye(2) - @test arg2(copy(a)) == arg2(a) - b = similar(a) - @test arg1(copyto!(b, a)) ≡ Eye(2) - @test arg2(copyto!(b, a)) == arg2(a) - @test arg1(permutedims(a, (2, 1))) ≡ Eye(2) - @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) - b = similar(a) - @test arg1(permutedims!(b, a, (2, 1))) ≡ Eye(2) - @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) - - a = randn(3, 3) ⊗ Eye(2) - @test size(a) == (6, 6) - @test a + a == (2 * arg1(a)) ⊗ Eye(2) - @test 2a == (2 * arg1(a)) ⊗ Eye(2) - @test a * a == (arg1(a) * arg1(a)) ⊗ Eye(2) - @test_broken arg2(a[(:) × (:), (:) × (:)]) ≡ Eye(2) - @test_broken arg2(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) - @test_broken arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) - @test_broken arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) - @test_broken arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) - @test_broken arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) - @test_broken arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ - Eye(2) - @test_broken arg2( - view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)) - ) ≡ Eye(2) - @test arg2(adapt(JLArray, a)) ≡ Eye(2) - @test arg1(adapt(JLArray, a)) == jl(arg1(a)) - @test arg1(adapt(JLArray, a)) isa JLArray - @test_broken arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ Eye(3) - @test_broken arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ - Eye(3) - @test_broken arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ - Eye{Float32}(3) - @test arg2(copy(a)) ≡ Eye(2) - @test arg2(copy(a)) == arg2(a) - b = similar(a) - @test arg2(copyto!(b, a)) ≡ Eye(2) - @test arg2(copyto!(b, a)) == arg2(a) - @test arg2(permutedims(a, (2, 1))) ≡ Eye(2) - @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) - b = similar(a) - @test arg2(permutedims!(b, a, (2, 1))) ≡ Eye(2) - @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) - - a = δ(2, 2) ⊗ randn(3, 3) - @test size(a) == (6, 6) - @test a + a == δ(2, 2) ⊗ (2 * arg2(a)) - @test 2a == δ(2, 2) ⊗ (2 * arg2(a)) - @test a * a == δ(2, 2) ⊗ (arg2(a) * arg2(a)) - @test arg1(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) - @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) - @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) - @test arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) - @test arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) - @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) - @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ - δ(2, 2) - @test arg1(adapt(JLArray, a)) ≡ δ(2, 2) - @test arg2(adapt(JLArray, a)) == jl(arg2(a)) - @test arg2(adapt(JLArray, a)) isa JLArray - @test_broken arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(3, 3) - @test_broken arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ - δ(3, 3) - @test_broken arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ - δ(Float32, 3, 3) - @test arg1(copy(a)) ≡ δ(2, 2) - @test arg2(copy(a)) == arg2(a) - b = similar(a) - @test arg1(copyto!(b, a)) ≡ δ(2, 2) - @test arg2(copyto!(b, a)) == arg2(a) - @test arg1(permutedims(a, (2, 1))) ≡ δ(2, 2) - @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) - b = similar(a) - @test arg1(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) - @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) - - a = randn(3, 3) ⊗ δ(2, 2) - @test size(a) == (6, 6) - @test a + a == (2 * arg1(a)) ⊗ δ(2, 2) - @test 2a == (2 * arg1(a)) ⊗ δ(2, 2) - @test a * a == (arg1(a) * arg1(a)) ⊗ δ(2, 2) - @test arg2(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) - @test arg2(view(a, (:) × (:), (:) × (:))) ≡ δ(2, 2) - @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) - @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) - @test arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) - @test arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) - @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) - @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ - δ(2, 2) - @test arg2(adapt(JLArray, a)) ≡ δ(2, 2) - @test arg1(adapt(JLArray, a)) == jl(arg1(a)) - @test arg1(adapt(JLArray, a)) isa JLArray - @test_broken arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ δ(3, 3) - @test_broken arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ - δ(3, 3) - @test_broken arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ - δ(Float32, (3, 3)) - @test arg2(copy(a)) ≡ δ(2, 2) - @test arg2(copy(a)) == arg2(a) - b = similar(a) - @test arg2(copyto!(b, a)) ≡ δ(2, 2) - @test arg2(copyto!(b, a)) == arg2(a) - @test arg2(permutedims(a, (2, 1))) ≡ δ(2, 2) - @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) - b = similar(a) - @test arg2(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) - @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) - - # Views - a = @constinferred(Eye(2) ⊗ randn(3, 3)) - b = @constinferred(view(a, (:) × (2:3), (:) × (2:3))) - @test_broken arg1(b) ≡ Eye(2) - @test arg2(b) ≡ view(arg2(a), 2:3, 2:3) - @test arg2(b) == arg2(a)[2:3, 2:3] - - a = randn(3, 3) ⊗ Eye(2) - @test size(a) == (6, 6) - @test a + a == (2arg1(a)) ⊗ Eye(2) - @test 2a == (2arg1(a)) ⊗ Eye(2) - @test a * a == (arg1(a) * arg1(a)) ⊗ Eye(2) - - # Views - a = @constinferred(randn(3, 3) ⊗ Eye(2)) - b = @constinferred(view(a, (2:3) × (:), (2:3) × (:))) - @test arg1(b) ≡ view(arg1(a), 2:3, 2:3) - @test arg1(b) == arg1(a)[2:3, 2:3] - @test_broken arg2(b) ≡ Eye(2) - - # similar - a = Eye(2) ⊗ randn(3, 3) - a′ = similar(a) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - @test arg1(a′) ≡ arg1(a) - - a = Eye(2) ⊗ randn(3, 3) - a′ = similar(a, eltype(a)) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - @test arg1(a′) ≡ arg1(a) - - a = Eye(2) ⊗ randn(3, 3) - a′ = similar(a, axes(a)) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - @test arg1(a′) ≡ arg1(a) - - a = Eye(2) ⊗ randn(3, 3) - a′ = similar(a, eltype(a), axes(a)) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - @test arg1(a′) ≡ arg1(a) - - @test_broken similar(typeof(a), axes(a)) - - a = Eye(2) ⊗ randn(3, 3) - a′ = similar(a, Float32) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{Float32,ndims(a)} - @test_broken arg1(a′) ≡ Eye{Float32}(2) - - a = Eye(2) ⊗ randn(3, 3) - a′ = similar(a, Float32, axes(a)) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{Float32,ndims(a)} - @test_broken arg1(a′) ≡ Eye{Float32}(2) - - a = randn(3, 3) ⊗ Eye(2) - a′ = similar(a) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - @test arg2(a′) ≡ arg2(a) - - a = randn(3, 3) ⊗ Eye(2) - a′ = similar(a, eltype(a)) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - @test arg2(a′) ≡ arg2(a) - - a = randn(3, 3) ⊗ Eye(2) - a′ = similar(a, axes(a)) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - @test arg2(a′) ≡ arg2(a) - - a = randn(3, 3) ⊗ Eye(2) - a′ = similar(a, eltype(a), axes(a)) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - @test arg2(a′) ≡ arg2(a) - - @test_broken similar(typeof(a), axes(a)) - - a = randn(3, 3) ⊗ Eye(2) - a′ = similar(a, Float32) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{Float32,ndims(a)} - # This is broken because of: - # https://github.com/JuliaArrays/FillArrays.jl/issues/415 - @test_broken arg2(a′) ≡ Eye{Float32}(2) - - a = randn(3, 3) ⊗ Eye(2) - a′ = similar(a, Float32, axes(a)) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{Float32,ndims(a)} - - a = Eye(3) ⊗ Eye(2) - for a′ in ( - similar(a), similar(a, eltype(a)), similar(a, axes(a)), similar(a, eltype(a), axes(a)) - ) + MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS + if VERSION < v"1.11-" + # `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11. + MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt]) + end + + a = Eye(2) ⊗ randn(3, 3) + @test size(a) == (6, 6) + @test a + a == Eye(2) ⊗ (2 * arg2(a)) + @test 2a == Eye(2) ⊗ (2 * arg2(a)) + @test a * a == Eye(2) ⊗ (arg2(a) * arg2(a)) + @test_broken arg1(a[(:) × (:), (:) × (:)]) ≡ Eye(2) + @test_broken arg1(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) + @test_broken arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) + @test_broken arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) + @test_broken arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test_broken arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) + @test_broken arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ + Eye(2) + @test_broken arg1( + view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)) + ) ≡ Eye(2) + @test arg1(adapt(JLArray, a)) ≡ Eye(2) + @test arg2(adapt(JLArray, a)) == jl(arg2(a)) + @test arg2(adapt(JLArray, a)) isa JLArray + @test_broken arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye(3) + @test_broken arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + Eye(3) + @test_broken arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + Eye{Float32}(3) + @test arg1(copy(a)) ≡ Eye(2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg1(copyto!(b, a)) ≡ Eye(2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg1(permutedims(a, (2, 1))) ≡ Eye(2) + @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) + b = similar(a) + @test arg1(permutedims!(b, a, (2, 1))) ≡ Eye(2) + @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) + + a = randn(3, 3) ⊗ Eye(2) + @test size(a) == (6, 6) + @test a + a == (2 * arg1(a)) ⊗ Eye(2) + @test 2a == (2 * arg1(a)) ⊗ Eye(2) + @test a * a == (arg1(a) * arg1(a)) ⊗ Eye(2) + @test_broken arg2(a[(:) × (:), (:) × (:)]) ≡ Eye(2) + @test_broken arg2(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) + @test_broken arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) + @test_broken arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) + @test_broken arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test_broken arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) + @test_broken arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ + Eye(2) + @test_broken arg2( + view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)) + ) ≡ Eye(2) + @test arg2(adapt(JLArray, a)) ≡ Eye(2) + @test arg1(adapt(JLArray, a)) == jl(arg1(a)) + @test arg1(adapt(JLArray, a)) isa JLArray + @test_broken arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ Eye(3) + @test_broken arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + Eye(3) + @test_broken arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + Eye{Float32}(3) + @test arg2(copy(a)) ≡ Eye(2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg2(copyto!(b, a)) ≡ Eye(2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg2(permutedims(a, (2, 1))) ≡ Eye(2) + @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) + b = similar(a) + @test arg2(permutedims!(b, a, (2, 1))) ≡ Eye(2) + @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) + + a = δ(2, 2) ⊗ randn(3, 3) + @test size(a) == (6, 6) + @test a + a == δ(2, 2) ⊗ (2 * arg2(a)) + @test 2a == δ(2, 2) ⊗ (2 * arg2(a)) + @test a * a == δ(2, 2) ⊗ (arg2(a) * arg2(a)) + @test arg1(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + δ(2, 2) + @test arg1(adapt(JLArray, a)) ≡ δ(2, 2) + @test arg2(adapt(JLArray, a)) == jl(arg2(a)) + @test arg2(adapt(JLArray, a)) isa JLArray + @test_broken arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(3, 3) + @test_broken arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + δ(3, 3) + @test_broken arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + δ(Float32, 3, 3) + @test arg1(copy(a)) ≡ δ(2, 2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg1(copyto!(b, a)) ≡ δ(2, 2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg1(permutedims(a, (2, 1))) ≡ δ(2, 2) + @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) + b = similar(a) + @test arg1(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) + @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) + + a = randn(3, 3) ⊗ δ(2, 2) + @test size(a) == (6, 6) + @test a + a == (2 * arg1(a)) ⊗ δ(2, 2) + @test 2a == (2 * arg1(a)) ⊗ δ(2, 2) + @test a * a == (arg1(a) * arg1(a)) ⊗ δ(2, 2) + @test arg2(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, (:) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + δ(2, 2) + @test arg2(adapt(JLArray, a)) ≡ δ(2, 2) + @test arg1(adapt(JLArray, a)) == jl(arg1(a)) + @test arg1(adapt(JLArray, a)) isa JLArray + @test_broken arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ δ(3, 3) + @test_broken arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + δ(3, 3) + @test_broken arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + δ(Float32, (3, 3)) + @test arg2(copy(a)) ≡ δ(2, 2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg2(copyto!(b, a)) ≡ δ(2, 2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg2(permutedims(a, (2, 1))) ≡ δ(2, 2) + @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) + b = similar(a) + @test arg2(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) + @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) + + # Views + a = @constinferred(Eye(2) ⊗ randn(3, 3)) + b = @constinferred(view(a, (:) × (2:3), (:) × (2:3))) + @test_broken arg1(b) ≡ Eye(2) + @test arg2(b) ≡ view(arg2(a), 2:3, 2:3) + @test arg2(b) == arg2(a)[2:3, 2:3] + + a = randn(3, 3) ⊗ Eye(2) + @test size(a) == (6, 6) + @test a + a == (2arg1(a)) ⊗ Eye(2) + @test 2a == (2arg1(a)) ⊗ Eye(2) + @test a * a == (arg1(a) * arg1(a)) ⊗ Eye(2) + + # Views + a = @constinferred(randn(3, 3) ⊗ Eye(2)) + b = @constinferred(view(a, (2:3) × (:), (2:3) × (:))) + @test arg1(b) ≡ view(arg1(a), 2:3, 2:3) + @test arg1(b) == arg1(a)[2:3, 2:3] + @test_broken arg2(b) ≡ Eye(2) + + # similar + a = Eye(2) ⊗ randn(3, 3) + a′ = similar(a) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + @test arg1(a′) ≡ arg1(a) + + a = Eye(2) ⊗ randn(3, 3) + a′ = similar(a, eltype(a)) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + @test arg1(a′) ≡ arg1(a) + + a = Eye(2) ⊗ randn(3, 3) + a′ = similar(a, axes(a)) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + @test arg1(a′) ≡ arg1(a) + + a = Eye(2) ⊗ randn(3, 3) + a′ = similar(a, eltype(a), axes(a)) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + @test arg1(a′) ≡ arg1(a) + + @test_broken similar(typeof(a), axes(a)) + + a = Eye(2) ⊗ randn(3, 3) + a′ = similar(a, Float32) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32, ndims(a)} + @test_broken arg1(a′) ≡ Eye{Float32}(2) + + a = Eye(2) ⊗ randn(3, 3) + a′ = similar(a, Float32, axes(a)) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32, ndims(a)} + @test_broken arg1(a′) ≡ Eye{Float32}(2) + + a = randn(3, 3) ⊗ Eye(2) + a′ = similar(a) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + @test arg2(a′) ≡ arg2(a) + + a = randn(3, 3) ⊗ Eye(2) + a′ = similar(a, eltype(a)) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + @test arg2(a′) ≡ arg2(a) + + a = randn(3, 3) ⊗ Eye(2) + a′ = similar(a, axes(a)) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + @test arg2(a′) ≡ arg2(a) + + a = randn(3, 3) ⊗ Eye(2) + a′ = similar(a, eltype(a), axes(a)) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + @test arg2(a′) ≡ arg2(a) + + @test_broken similar(typeof(a), axes(a)) + + a = randn(3, 3) ⊗ Eye(2) + a′ = similar(a, Float32) @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a)} - end - @test_broken similar(typeof(a), axes(a)) + @test a′ isa KroneckerArray{Float32, ndims(a)} + # This is broken because of: + # https://github.com/JuliaArrays/FillArrays.jl/issues/415 + @test_broken arg2(a′) ≡ Eye{Float32}(2) - a = Eye(3) ⊗ Eye(2) - for args in ((Float32,), (Float32, axes(a))) - a′ = similar(a, args...) + a = randn(3, 3) ⊗ Eye(2) + a′ = similar(a, Float32, axes(a)) @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{Float32,ndims(a)} - end - - # DerivableInterfaces.zero! - for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) - zero!(a) - @test iszero(a) - end - a = Eye(3) ⊗ Eye(2) - @test_throws ErrorException zero!(a) - - # map!(+, ...) - for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + @test a′ isa KroneckerArray{Float32, ndims(a)} + + a = Eye(3) ⊗ Eye(2) + for a′ in ( + similar(a), similar(a, eltype(a)), similar(a, axes(a)), similar(a, eltype(a), axes(a)), + ) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a), ndims(a)} + end + @test_broken similar(typeof(a), axes(a)) + + a = Eye(3) ⊗ Eye(2) + for args in ((Float32,), (Float32, axes(a))) + a′ = similar(a, args...) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32, ndims(a)} + end + + # DerivableInterfaces.zero! + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + zero!(a) + @test iszero(a) + end + a = Eye(3) ⊗ Eye(2) + @test_throws ErrorException zero!(a) + + # map!(+, ...) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(+, a′, a, a) + @test collect(a′) ≈ 2 * collect(a) + end + a = Eye(3) ⊗ Eye(2) a′ = similar(a) map!(+, a′, a, a) - @test collect(a′) ≈ 2 * collect(a) - end - a = Eye(3) ⊗ Eye(2) - a′ = similar(a) - map!(+, a′, a, a) - @test a′ ≈ 2a - - # map!(-, ...) - for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + @test a′ ≈ 2a + + # map!(-, ...) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(-, a′, a, a) + @test norm(collect(a′)) ≈ 0 + end + a = Eye(3) ⊗ Eye(2) a′ = similar(a) map!(-, a′, a, a) - @test norm(collect(a′)) ≈ 0 - end - a = Eye(3) ⊗ Eye(2) - a′ = similar(a) - map!(-, a′, a, a) - @test iszero(a′) - - # map!(-, b, a) - for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + @test iszero(a′) + + # map!(-, b, a) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(-, a′, a) + @test collect(a′) ≈ -collect(a) + end + a = Eye(3) ⊗ Eye(2) a′ = similar(a) map!(-, a′, a) - @test collect(a′) ≈ -collect(a) - end - a = Eye(3) ⊗ Eye(2) - a′ = similar(a) - map!(-, a′, a) - @test a′ ≈ -a - - ## # Eye ⊗ A - ## rng = StableRNG(123) - ## a = Eye(2) ⊗ randn(rng, 3, 3) - ## for f in MATRIX_FUNCTIONS - ## @eval begin - ## fa = $f($a) - ## @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) - ## @test arg1(fa) isa Eye - ## end - ## end - - fa = inv(a) - @test collect(fa) ≈ inv(collect(a)) - @test arg1(fa) isa Eye - - fa = pinv(a) - @test collect(fa) ≈ pinv(collect(a)) - @test_broken arg1(fa) isa Eye - - @test det(a) ≈ det(collect(a)) - - ## # A ⊗ Eye - ## rng = StableRNG(123) - ## a = randn(rng, 3, 3) ⊗ Eye(2) - ## for f in setdiff(MATRIX_FUNCTIONS, [:atanh]) - ## @eval begin - ## fa = $f($a) - ## @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) - ## @test arg2(fa) isa Eye - ## end - ## end - - fa = inv(a) - @test collect(fa) ≈ inv(collect(a)) - @test arg2(fa) isa Eye - - fa = pinv(a) - @test collect(fa) ≈ pinv(collect(a)) - @test_broken arg2(fa) isa Eye - - @test det(a) ≈ det(collect(a)) - - # Eye ⊗ Eye - a = Eye(2) ⊗ Eye(2) - for f in MATRIX_FUNCTIONS - @eval begin - @test $f($a) == arg1($a) ⊗ $f(arg2($a)) + @test a′ ≈ -a + + ## # Eye ⊗ A + ## rng = StableRNG(123) + ## a = Eye(2) ⊗ randn(rng, 3, 3) + ## for f in MATRIX_FUNCTIONS + ## @eval begin + ## fa = $f($a) + ## @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) + ## @test arg1(fa) isa Eye + ## end + ## end + + fa = inv(a) + @test collect(fa) ≈ inv(collect(a)) + @test arg1(fa) isa Eye + + fa = pinv(a) + @test collect(fa) ≈ pinv(collect(a)) + @test_broken arg1(fa) isa Eye + + @test det(a) ≈ det(collect(a)) + + ## # A ⊗ Eye + ## rng = StableRNG(123) + ## a = randn(rng, 3, 3) ⊗ Eye(2) + ## for f in setdiff(MATRIX_FUNCTIONS, [:atanh]) + ## @eval begin + ## fa = $f($a) + ## @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) + ## @test arg2(fa) isa Eye + ## end + ## end + + fa = inv(a) + @test collect(fa) ≈ inv(collect(a)) + @test arg2(fa) isa Eye + + fa = pinv(a) + @test collect(fa) ≈ pinv(collect(a)) + @test_broken arg2(fa) isa Eye + + @test det(a) ≈ det(collect(a)) + + # Eye ⊗ Eye + a = Eye(2) ⊗ Eye(2) + for f in MATRIX_FUNCTIONS + @eval begin + @test $f($a) == arg1($a) ⊗ $f(arg2($a)) + end end - end - fa = inv(a) - @test fa == a - @test arg1(fa) isa Eye - @test arg2(fa) isa Eye + fa = inv(a) + @test fa == a + @test arg1(fa) isa Eye + @test arg2(fa) isa Eye - fa = pinv(a) - @test fa == a - @test_broken arg1(fa) isa Eye - @test_broken arg2(fa) isa Eye + fa = pinv(a) + @test fa == a + @test_broken arg1(fa) isa Eye + @test_broken arg2(fa) isa Eye - @test det(a) ≈ det(collect(a)) ≈ 1 + @test det(a) ≈ det(collect(a)) ≈ 1 - # permutedims - a = Eye(2, 2) ⊗ randn(3, 3) - @test permutedims(a, (2, 1)) == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) + # permutedims + a = Eye(2, 2) ⊗ randn(3, 3) + @test permutedims(a, (2, 1)) == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) - a = randn(2, 2) ⊗ Eye(3, 3) - @test permutedims(a, (2, 1)) == permutedims(arg1(a), (2, 1)) ⊗ Eye(3, 3) + a = randn(2, 2) ⊗ Eye(3, 3) + @test permutedims(a, (2, 1)) == permutedims(arg1(a), (2, 1)) ⊗ Eye(3, 3) - # permutedims! - a = Eye(2, 2) ⊗ randn(3, 3) - b = similar(a) - permutedims!(b, a, (2, 1)) - @test b == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) + # permutedims! + a = Eye(2, 2) ⊗ randn(3, 3) + b = similar(a) + permutedims!(b, a, (2, 1)) + @test b == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) - a = randn(3, 3) ⊗ Eye(2, 2) - b = similar(a) - permutedims!(b, a, (2, 1)) - @test b == permutedims(arg1(a), (2, 1)) ⊗ Eye(2, 2) + a = randn(3, 3) ⊗ Eye(2, 2) + b = similar(a) + permutedims!(b, a, (2, 1)) + @test b == permutedims(arg1(a), (2, 1)) ⊗ Eye(2, 2) end @testset "FillArrays.Zeros" begin - a = randn(2, 2) ⊗ randn(2, 2) - b = Zeros(2, 2) ⊗ Zeros(2, 2) - for (x, y) in ((a, b), (b, a)) - @test x + y == a - @test x .+ y == a - @test map!(+, similar(a), x, y) == a - @test (similar(a) .= x .+ y) == a - end - - @test a - b == a - @test a .- b == a - @test map!(-, similar(a), a, b) == a - @test (similar(a) .= a .- b) == a - - @test b - a == -a - @test b .- a == -a - @test map!(-, similar(a), b, a) == -a - @test (similar(a) .= b .- a) == -a - - @test b + b == b - @test b .+ b == b - @test b - b == b - @test b .- b == b + a = randn(2, 2) ⊗ randn(2, 2) + b = Zeros(2, 2) ⊗ Zeros(2, 2) + for (x, y) in ((a, b), (b, a)) + @test x + y == a + @test x .+ y == a + @test map!(+, similar(a), x, y) == a + @test (similar(a) .= x .+ y) == a + end + + @test a - b == a + @test a .- b == a + @test map!(-, similar(a), a, b) == a + @test (similar(a) .= a .- b) == a + + @test b - a == -a + @test b .- a == -a + @test map!(-, similar(a), b, a) == -a + @test (similar(a) .= b .- a) == -a + + @test b + b == b + @test b .+ b == b + @test b - b == b + @test b .- b == b end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index adc6974..0767542 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,129 +1,129 @@ using KroneckerArrays: ⊗, arguments using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm using MatrixAlgebraKit: - eig_full, - eig_trunc, - eig_vals, - eigh_full, - eigh_trunc, - eigh_vals, - left_null, - left_orth, - left_polar, - lq_compact, - lq_full, - qr_compact, - qr_full, - right_null, - right_orth, - right_polar, - svd_compact, - svd_full, - svd_trunc, - svd_vals + eig_full, + eig_trunc, + eig_vals, + eigh_full, + eigh_trunc, + eigh_vals, + left_null, + left_orth, + left_polar, + lq_compact, + lq_full, + qr_compact, + qr_full, + right_null, + right_orth, + right_polar, + svd_compact, + svd_full, + svd_trunc, + svd_vals using Test: @test, @test_throws, @testset using TestExtras: @constinferred herm(a) = parent(hermitianpart(a)) @testset "MatrixAlgebraKit" begin - elt = Float32 - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - d, v = eig_full(a) - @test a * v ≈ v * d - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - @test_throws ArgumentError eig_trunc(a) - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - d = eig_vals(a) - @test d ≈ diag(eig_full(a)[1]) - - a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) - d, v = eigh_full(a) - @test a * v ≈ v * d - @test eltype(d) === real(elt) - @test eltype(v) === elt - - a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) - @test_throws ArgumentError eigh_trunc(a) - - a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) - d = eigh_vals(a) - @test d ≈ diag(eigh_full(a)[1]) - @test eltype(d) === real(elt) - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - u, c = qr_compact(a) - @test u * c ≈ a - @test collect(u'u) ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - u, c = qr_full(a) - @test u * c ≈ a - @test collect(u'u) ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c, u = lq_compact(a) - @test c * u ≈ a - @test collect(u * u') ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c, u = lq_full(a) - @test c * u ≈ a - @test collect(u * u') ≈ I - - a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3) - n = left_null(a) - @test norm(n' * a) ≈ 0 atol = √eps(real(elt)) - - a = randn(elt, 2, 3) ⊗ randn(elt, 3, 4) - n = right_null(a) - @test norm(a * n') ≈ 0 atol = √eps(real(elt)) - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - u, c = left_orth(a) - @test u * c ≈ a - @test collect(u'u) ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c, u = right_orth(a) - @test c * u ≈ a - @test collect(u * u') ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - u, c = left_polar(a) - @test u * c ≈ a - @test collect(u'u) ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c, u = right_polar(a) - @test c * u ≈ a - @test collect(u * u') ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - u, s, v = svd_compact(a) - @test u * s * v ≈ a - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - @test collect(u'u) ≈ I - @test collect(v * v') ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - u, s, v = svd_full(a) - @test u * s * v ≈ a - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - @test collect(u'u) ≈ I - @test collect(v * v') ≈ I - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - @test_throws ArgumentError svd_trunc(a) - - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - s = svd_vals(a) - @test s ≈ diag(svd_compact(a)[2]) + elt = Float32 + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + d, v = eig_full(a) + @test a * v ≈ v * d + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + @test_throws ArgumentError eig_trunc(a) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + d = eig_vals(a) + @test d ≈ diag(eig_full(a)[1]) + + a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) + d, v = eigh_full(a) + @test a * v ≈ v * d + @test eltype(d) === real(elt) + @test eltype(v) === elt + + a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) + @test_throws ArgumentError eigh_trunc(a) + + a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) + d = eigh_vals(a) + @test d ≈ diag(eigh_full(a)[1]) + @test eltype(d) === real(elt) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = qr_compact(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = qr_full(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = lq_compact(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = lq_full(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3) + n = left_null(a) + @test norm(n' * a) ≈ 0 atol = √eps(real(elt)) + + a = randn(elt, 2, 3) ⊗ randn(elt, 3, 4) + n = right_null(a) + @test norm(a * n') ≈ 0 atol = √eps(real(elt)) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = left_orth(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = right_orth(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = left_polar(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = right_polar(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, s, v = svd_compact(a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test collect(u'u) ≈ I + @test collect(v * v') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, s, v = svd_full(a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test collect(u'u) ≈ I + @test collect(v * v') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + @test_throws ArgumentError svd_trunc(a) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + s = svd_vals(a) + @test s ≈ diag(svd_compact(a)[2]) end diff --git a/test/test_matrixalgebrakit_delta.jl b/test/test_matrixalgebrakit_delta.jl index f4a8a61..a693b1e 100644 --- a/test/test_matrixalgebrakit_delta.jl +++ b/test/test_matrixalgebrakit_delta.jl @@ -3,283 +3,283 @@ using DiagonalArrays: δ, DeltaMatrix using KroneckerArrays: ⊗, arguments using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm using MatrixAlgebraKit: - eig_full, - eig_trunc, - eig_vals, - eigh_full, - eigh_trunc, - eigh_vals, - left_null, - left_orth, - left_polar, - lq_compact, - lq_full, - qr_compact, - qr_full, - right_null, - right_orth, - right_polar, - svd_compact, - svd_full, - svd_trunc, - svd_vals + eig_full, + eig_trunc, + eig_vals, + eigh_full, + eigh_trunc, + eigh_vals, + left_null, + left_orth, + left_polar, + lq_compact, + lq_full, + qr_compact, + qr_full, + right_null, + right_orth, + right_polar, + svd_compact, + svd_full, + svd_trunc, + svd_vals using Test: @test, @test_broken, @test_throws, @testset using TestExtras: @constinferred herm(a) = parent(hermitianpart(a)) @testset "MatrixAlgebraKit + DeltaMatrix" begin - for elt in (Float32, ComplexF32) - a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) - d, v = @constinferred eig_full(a) - @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix{complex(elt)} - @test arguments(v, 1) isa DeltaMatrix{complex(elt)} - - a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ δ(elt, 3, 3) - d, v = @constinferred eig_full(a) - @test a * v ≈ v * d - @test arguments(d, 2) isa DeltaMatrix{complex(elt)} - @test arguments(v, 2) isa DeltaMatrix{complex(elt)} + for elt in (Float32, ComplexF32) + a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) + d, v = @constinferred eig_full(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa DeltaMatrix{complex(elt)} + @test arguments(v, 1) isa DeltaMatrix{complex(elt)} - a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) - d, v = @constinferred eig_full(a) - @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix{complex(elt)} - @test arguments(d, 2) isa DeltaMatrix{complex(elt)} - @test arguments(v, 1) isa DeltaMatrix{complex(elt)} - @test arguments(v, 2) isa DeltaMatrix{complex(elt)} - end + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ δ(elt, 3, 3) + d, v = @constinferred eig_full(a) + @test a * v ≈ v * d + @test arguments(d, 2) isa DeltaMatrix{complex(elt)} + @test arguments(v, 2) isa DeltaMatrix{complex(elt)} - for elt in (Float32, ComplexF32) - a = δ(elt, 3, 3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) - d, v = @constinferred eigh_full($a) - @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} + a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) + d, v = @constinferred eig_full(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa DeltaMatrix{complex(elt)} + @test arguments(d, 2) isa DeltaMatrix{complex(elt)} + @test arguments(v, 1) isa DeltaMatrix{complex(elt)} + @test arguments(v, 2) isa DeltaMatrix{complex(elt)} + end - a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ δ(elt, 3, 3) - d, v = @constinferred eigh_full($a) - @test a * v ≈ v * d - @test arguments(d, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 2) isa DeltaMatrix{elt} + for elt in (Float32, ComplexF32) + a = δ(elt, 3, 3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) + d, v = @constinferred eigh_full($a) + @test a * v ≈ v * d + @test arguments(d, 1) isa DeltaMatrix{real(elt)} + @test arguments(v, 1) isa DeltaMatrix{elt} - a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) - d, v = @constinferred eigh_full($a) - @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix{real(elt)} - @test arguments(d, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} - @test arguments(v, 2) isa DeltaMatrix{elt} - end + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ δ(elt, 3, 3) + d, v = @constinferred eigh_full($a) + @test a * v ≈ v * d + @test arguments(d, 2) isa DeltaMatrix{real(elt)} + @test arguments(v, 2) isa DeltaMatrix{elt} - for f in (eig_trunc, eigh_trunc) - a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3))) - d, v = f(a; trunc=(; maxrank=7)) - @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix - @test arguments(v, 1) isa DeltaMatrix - @test size(d) == (6, 6) - @test size(v) == (9, 6) + a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) + d, v = @constinferred eigh_full($a) + @test a * v ≈ v * d + @test arguments(d, 1) isa DeltaMatrix{real(elt)} + @test arguments(d, 2) isa DeltaMatrix{real(elt)} + @test arguments(v, 1) isa DeltaMatrix{elt} + @test arguments(v, 2) isa DeltaMatrix{elt} + end - a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3) - d, v = f(a; trunc=(; maxrank=7)) - @test a * v ≈ v * d - @test arguments(d, 2) isa DeltaMatrix - @test arguments(v, 2) isa DeltaMatrix - @test size(d) == (6, 6) - @test size(v) == (9, 6) + for f in (eig_trunc, eigh_trunc) + a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3))) + d, v = f(a; trunc = (; maxrank = 7)) + @test a * v ≈ v * d + @test arguments(d, 1) isa DeltaMatrix + @test arguments(v, 1) isa DeltaMatrix + @test size(d) == (6, 6) + @test size(v) == (9, 6) - a = δ(3, 3) ⊗ δ(3, 3) - @test_throws ArgumentError f(a) - end + a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3) + d, v = f(a; trunc = (; maxrank = 7)) + @test a * v ≈ v * d + @test arguments(d, 2) isa DeltaMatrix + @test arguments(v, 2) isa DeltaMatrix + @test size(d) == (6, 6) + @test size(v) == (9, 6) - for f in (eig_vals, eigh_vals) - a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3))) - d = @constinferred f(a) - d′ = f(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 1) isa Ones - @test arguments(d, 2) ≈ f(arguments(a, 2)) + a = δ(3, 3) ⊗ δ(3, 3) + @test_throws ArgumentError f(a) + end - a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3) - d = @constinferred f(a) - d′ = f(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 2) isa Ones - @test arguments(d, 1) ≈ f(arguments(a, 1)) + for f in (eig_vals, eigh_vals) + a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3))) + d = @constinferred f(a) + d′ = f(Matrix(a)) + @test sort(Vector(d); by = abs) ≈ sort(d′; by = abs) + @test arguments(d, 1) isa Ones + @test arguments(d, 2) ≈ f(arguments(a, 2)) - a = δ(3, 3) ⊗ δ(3, 3) - d = @constinferred f(a) - @test d == Ones(3) ⊗ Ones(3) - @test arguments(d, 1) isa Ones - @test arguments(d, 2) isa Ones - end + a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3) + d = @constinferred f(a) + d′ = f(Matrix(a)) + @test sort(Vector(d); by = abs) ≈ sort(d′; by = abs) + @test arguments(d, 2) isa Ones + @test arguments(d, 1) ≈ f(arguments(a, 1)) - for f in ( - left_orth, right_orth, left_polar, right_polar, qr_compact, lq_compact, qr_full, lq_full - ) - a = δ(3, 3) ⊗ randn(3, 3) - if VERSION ≥ v"1.11-" - x, y = @constinferred f($a) - else - # Type inference fails in Julia 1.10. - x, y = f(a) + a = δ(3, 3) ⊗ δ(3, 3) + d = @constinferred f(a) + @test d == Ones(3) ⊗ Ones(3) + @test arguments(d, 1) isa Ones + @test arguments(d, 2) isa Ones end - @test x * y ≈ a - @test arguments(x, 1) isa DeltaMatrix - @test arguments(y, 1) isa DeltaMatrix - a = randn(3, 3) ⊗ δ(3, 3) - x, y = @constinferred f($a) - @test x * y ≈ a - @test arguments(x, 2) isa DeltaMatrix - @test arguments(y, 2) isa DeltaMatrix + for f in ( + left_orth, right_orth, left_polar, right_polar, qr_compact, lq_compact, qr_full, lq_full, + ) + a = δ(3, 3) ⊗ randn(3, 3) + if VERSION ≥ v"1.11-" + x, y = @constinferred f($a) + else + # Type inference fails in Julia 1.10. + x, y = f(a) + end + @test x * y ≈ a + @test arguments(x, 1) isa DeltaMatrix + @test arguments(y, 1) isa DeltaMatrix - a = δ(3, 3) ⊗ δ(3, 3) - x, y = @constinferred f($a) - @test x * y ≈ a - @test arguments(x, 1) isa DeltaMatrix - @test arguments(y, 1) isa DeltaMatrix - @test arguments(x, 2) isa DeltaMatrix - @test arguments(y, 2) isa DeltaMatrix - end + a = randn(3, 3) ⊗ δ(3, 3) + x, y = @constinferred f($a) + @test x * y ≈ a + @test arguments(x, 2) isa DeltaMatrix + @test arguments(y, 2) isa DeltaMatrix - for f in (svd_compact, svd_full) - for elt in (Float32, ComplexF32) - a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) - u, s, v = @constinferred f($a) - @test u * s * v ≈ a - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - @test arguments(u, 1) isa DeltaMatrix{elt} - @test arguments(s, 1) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} + a = δ(3, 3) ⊗ δ(3, 3) + x, y = @constinferred f($a) + @test x * y ≈ a + @test arguments(x, 1) isa DeltaMatrix + @test arguments(y, 1) isa DeltaMatrix + @test arguments(x, 2) isa DeltaMatrix + @test arguments(y, 2) isa DeltaMatrix + end + + for f in (svd_compact, svd_full) + for elt in (Float32, ComplexF32) + a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) + u, s, v = @constinferred f($a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 1) isa DeltaMatrix{elt} + @test arguments(s, 1) isa DeltaMatrix{real(elt)} + @test arguments(v, 1) isa DeltaMatrix{elt} - a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) - u, s, v = @constinferred f($a) - @test u * s * v ≈ a - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - @test arguments(u, 2) isa DeltaMatrix{elt} - @test arguments(s, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 2) isa DeltaMatrix{elt} + a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) + u, s, v = @constinferred f($a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 2) isa DeltaMatrix{elt} + @test arguments(s, 2) isa DeltaMatrix{real(elt)} + @test arguments(v, 2) isa DeltaMatrix{elt} - a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) - u, s, v = @constinferred f($a) - @test u * s * v ≈ a - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - @test arguments(u, 1) isa DeltaMatrix{elt} - @test arguments(s, 1) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} - @test arguments(u, 2) isa DeltaMatrix{elt} - @test arguments(s, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 2) isa DeltaMatrix{elt} + a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) + u, s, v = @constinferred f($a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 1) isa DeltaMatrix{elt} + @test arguments(s, 1) isa DeltaMatrix{real(elt)} + @test arguments(v, 1) isa DeltaMatrix{elt} + @test arguments(u, 2) isa DeltaMatrix{elt} + @test arguments(s, 2) isa DeltaMatrix{real(elt)} + @test arguments(v, 2) isa DeltaMatrix{elt} + end end - end - # svd_trunc - for elt in (Float32, ComplexF32) - a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) - # TODO: Type inference is broken for `svd_trunc`, - # look into fixing it. - # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 1) isa DeltaMatrix{elt} - @test arguments(s, 1) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} - @test size(u) == (9, 6) - @test size(s) == (6, 6) - @test size(v) == (6, 9) - end + # svd_trunc + for elt in (Float32, ComplexF32) + a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc = (; maxrank = 7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc = (; maxrank = 6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 1) isa DeltaMatrix{elt} + @test arguments(s, 1) isa DeltaMatrix{real(elt)} + @test arguments(v, 1) isa DeltaMatrix{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end - for elt in (Float32, ComplexF32) - a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) - # TODO: Type inference is broken for `svd_trunc`, - # look into fixing it. - # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 2) isa DeltaMatrix{elt} - @test arguments(s, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 2) isa DeltaMatrix{elt} - @test size(u) == (9, 6) - @test size(s) == (6, 6) - @test size(v) == (6, 9) - end + for elt in (Float32, ComplexF32) + a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc = (; maxrank = 7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc = (; maxrank = 6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 2) isa DeltaMatrix{elt} + @test arguments(s, 2) isa DeltaMatrix{real(elt)} + @test arguments(v, 2) isa DeltaMatrix{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end - a = δ(3, 3) ⊗ δ(3, 3) - @test_broken svd_trunc(a) + a = δ(3, 3) ⊗ δ(3, 3) + @test_broken svd_trunc(a) - # svd_vals - for elt in (Float32, ComplexF32) - a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) - d = @constinferred svd_vals(a) - d′ = svd_vals(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 1) isa Ones{real(elt)} - @test arguments(d, 2) ≈ svd_vals(arguments(a, 2)) - end + # svd_vals + for elt in (Float32, ComplexF32) + a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) + d = @constinferred svd_vals(a) + d′ = svd_vals(Matrix(a)) + @test sort(Vector(d); by = abs) ≈ sort(d′; by = abs) + @test arguments(d, 1) isa Ones{real(elt)} + @test arguments(d, 2) ≈ svd_vals(arguments(a, 2)) + end - for elt in (Float32, ComplexF32) - a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) - d = @constinferred svd_vals(a) - d′ = svd_vals(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 2) isa Ones{real(elt)} - @test arguments(d, 1) ≈ svd_vals(arguments(a, 1)) - end + for elt in (Float32, ComplexF32) + a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) + d = @constinferred svd_vals(a) + d′ = svd_vals(Matrix(a)) + @test sort(Vector(d); by = abs) ≈ sort(d′; by = abs) + @test arguments(d, 2) isa Ones{real(elt)} + @test arguments(d, 1) ≈ svd_vals(arguments(a, 1)) + end - for elt in (Float32, ComplexF32) - a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) - d = @constinferred svd_vals(a) - @test d ≡ Ones{real(elt)}(3) ⊗ Ones{real(elt)}(3) - @test arguments(d, 1) isa Ones{real(elt)} - @test arguments(d, 2) isa Ones{real(elt)} - end + for elt in (Float32, ComplexF32) + a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) + d = @constinferred svd_vals(a) + @test d ≡ Ones{real(elt)}(3) ⊗ Ones{real(elt)}(3) + @test arguments(d, 1) isa Ones{real(elt)} + @test arguments(d, 2) isa Ones{real(elt)} + end - # left_null - a = δ(3, 3) ⊗ randn(3, 3) - @test_broken left_null(a) - ## n = @constinferred left_null(a) - ## @test norm(n' * a) ≈ 0 - ## @test arguments(n, 1) isa DeltaMatrix + # left_null + a = δ(3, 3) ⊗ randn(3, 3) + @test_broken left_null(a) + ## n = @constinferred left_null(a) + ## @test norm(n' * a) ≈ 0 + ## @test arguments(n, 1) isa DeltaMatrix - a = randn(3, 3) ⊗ δ(3, 3) - @test_broken left_null(a) - ## n = @constinferred left_null(a) - ## @test norm(n' * a) ≈ 0 - ## @test arguments(n, 2) isa DeltaMatrix + a = randn(3, 3) ⊗ δ(3, 3) + @test_broken left_null(a) + ## n = @constinferred left_null(a) + ## @test norm(n' * a) ≈ 0 + ## @test arguments(n, 2) isa DeltaMatrix - a = δ(3, 3) ⊗ δ(3, 3) - @test_broken left_null(a) + a = δ(3, 3) ⊗ δ(3, 3) + @test_broken left_null(a) - # right_null - a = δ(3, 3) ⊗ randn(3, 3) - @test_broken right_null(a) - ## n = @constinferred right_null(a) - ## @test norm(a * n') ≈ 0 - ## @test arguments(n, 1) isa DeltaMatrix + # right_null + a = δ(3, 3) ⊗ randn(3, 3) + @test_broken right_null(a) + ## n = @constinferred right_null(a) + ## @test norm(a * n') ≈ 0 + ## @test arguments(n, 1) isa DeltaMatrix - a = randn(3, 3) ⊗ δ(3, 3) - @test_broken right_null(a) - ## n = @constinferred right_null(a) - ## @test norm(a * n') ≈ 0 - ## @test arguments(n, 2) isa DeltaMatrix + a = randn(3, 3) ⊗ δ(3, 3) + @test_broken right_null(a) + ## n = @constinferred right_null(a) + ## @test norm(a * n') ≈ 0 + ## @test arguments(n, 2) isa DeltaMatrix - a = δ(3, 3) ⊗ δ(3, 3) - @test_broken right_null(a) + a = δ(3, 3) ⊗ δ(3, 3) + @test_broken right_null(a) end diff --git a/test/test_tensoralgebra.jl b/test/test_tensoralgebra.jl index 35cac2a..97e02a4 100644 --- a/test/test_tensoralgebra.jl +++ b/test/test_tensoralgebra.jl @@ -3,8 +3,8 @@ using KroneckerArrays: ⊗, arg1, arg2 using Test: @test, @testset @testset "TensorAlgebraExt" begin - a = randn(2, 2, 2) ⊗ randn(3, 3, 3) - m = matricize(a, (1, 2), (3,)) - @test m == matricize(arg1(a), (1, 2), (3,)) ⊗ matricize(arg2(a), (1, 2), (3,)) - @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a + a = randn(2, 2, 2) ⊗ randn(3, 3, 3) + m = matricize(a, (1, 2), (3,)) + @test m == matricize(arg1(a), (1, 2), (3,)) ⊗ matricize(arg2(a), (1, 2), (3,)) + @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a end diff --git a/test/test_tensorproducts.jl b/test/test_tensorproducts.jl index 3fa3c79..2812966 100644 --- a/test/test_tensorproducts.jl +++ b/test/test_tensorproducts.jl @@ -3,11 +3,11 @@ using TensorProducts: tensor_product using Test: @test, @testset @testset "KroneckerArraysTensorProductsExt" begin - r1 = cartesianrange(2, 3) - r2 = cartesianrange(4, 5) - r = tensor_product(r1, r2) - @test r ≡ cartesianrange(8, 15) - @test arg1(r) ≡ Base.OneTo(8) - @test arg2(r) ≡ Base.OneTo(15) - @test unproduct(r) ≡ Base.OneTo(120) + r1 = cartesianrange(2, 3) + r2 = cartesianrange(4, 5) + r = tensor_product(r1, r2) + @test r ≡ cartesianrange(8, 15) + @test arg1(r) ≡ Base.OneTo(8) + @test arg2(r) ≡ Base.OneTo(15) + @test unproduct(r) ≡ Base.OneTo(120) end From 3fadfa01d9cfc8ad71647176b5da8329c23d9676 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 2 Oct 2025 15:51:48 -0400 Subject: [PATCH 2/2] Update skeleton --- .JuliaFormatter.toml | 3 --- .github/workflows/FormatCheck.yml | 13 ++++++++----- .gitignore | 4 ++++ .pre-commit-config.yaml | 8 ++++---- 4 files changed, 16 insertions(+), 12 deletions(-) delete mode 100644 .JuliaFormatter.toml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 4c49a86..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options -style = "blue" -indent = 2 diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index 3f78afc..1525861 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,11 +1,14 @@ name: "Format Check" on: - push: - branches: - - 'main' - tags: '*' - pull_request: + pull_request_target: + paths: ['**/*.jl'] + types: [opened, synchronize, reopened, ready_for_review] + +permissions: + contents: read + actions: write + pull-requests: write jobs: format-check: diff --git a/.gitignore b/.gitignore index 10593a9..7085ca8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,10 @@ .vscode/ Manifest.toml benchmark/*.json +dev/ +docs/LocalPreferences.toml docs/Manifest.toml docs/build/ docs/src/index.md +examples/LocalPreferences.toml +test/LocalPreferences.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88bc8b4..3fc4743 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ ci: - skip: [julia-formatter] + skip: [runic] repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -11,7 +11,7 @@ repos: - id: end-of-file-fixer exclude_types: [markdown] # incompatible with Literate.jl -- repo: "https://github.com/domluna/JuliaFormatter.jl" - rev: v2.1.6 +- repo: https://github.com/fredrikekre/runic-pre-commit + rev: v2.0.1 hooks: - - id: "julia-formatter" + - id: runic