From 88f6e13e9273712263e638dee87b22609cb14a5f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Nov 2025 10:32:08 -0500 Subject: [PATCH 1/8] small fixes --- src/kroneckerarray.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 9b93fb0..3d5d5af 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -121,7 +121,7 @@ end # TODO: copyto! is typically reserved for contiguous copies (i.e. also for copying from a # vector into an array), it might be better to not define that here. -function Base.copyto!(dest::KroneckerArray{<:Any, N}, src::KroneckerArray{<:Any, N}) where {N} +function Base.copyto!(dest::AbstractKroneckerArray{<:Any, N}, src::AbstractKroneckerArray{<:Any, N}) where {N} return mutate_active_args!(copyto!, copy, dest, src) end @@ -275,7 +275,7 @@ function DerivableInterfaces.zero!(a::AbstractKroneckerArray) end function Base.Array{T, N}(a::AbstractKroneckerArray{S, N}) where {T, S, N} - return convert(Array{T, N}, collect(T, a)) + return convert(Array{T, N}, collect(a)) end Base.size(a::AbstractKroneckerArray) = size(arg1(a)) .* size(arg2(a)) @@ -311,7 +311,7 @@ function Base.getindex(a::KroneckerArray, i::Integer) end using GPUArraysCore: GPUArraysCore -function Base.getindex(a::KroneckerArray{<:Any, N}, I::Vararg{Integer, N}) where {N} +function Base.getindex(a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Integer, N}) where {N} GPUArraysCore.assertscalar("getindex") I′ = ntuple(Val(N)) do dim return cartesianproduct(axes(a, dim))[I[dim]] @@ -600,7 +600,7 @@ Broadcast.materialize(a::KroneckerBroadcasted) = copy(a) Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a) Broadcast.broadcastable(a::KroneckerBroadcasted) = a Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) ⊗ copy(arg2(a)) -function Base.copyto!(dest::KroneckerArray, src::KroneckerBroadcasted) +function Base.copyto!(dest::AbstractKroneckerArray, src::KroneckerBroadcasted) return mutate_active_args!(copyto!, copy, dest, src) end function Base.eltype(a::KroneckerBroadcasted) From 04b052f9be236e78ccd2f78e16a66235e7be37f4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Nov 2025 14:37:11 -0500 Subject: [PATCH 2/8] refactor interface --- .../KroneckerArraysBlockSparseArraysExt.jl | 82 +-- .../KroneckerArraysTensorAlgebraExt.jl | 22 +- .../KroneckerArraysTensorProductsExt.jl | 11 +- src/KroneckerArrays.jl | 45 ++ src/cartesianproduct.jl | 389 ++++++----- src/fillarrays.jl | 2 +- src/kroneckerarray.jl | 610 ++++++++---------- src/linearalgebra.jl | 163 ++--- src/matrixalgebrakit.jl | 195 +++--- 9 files changed, 734 insertions(+), 785 deletions(-) diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index fb44ce3..d41e852 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -1,70 +1,51 @@ module KroneckerArraysBlockSparseArraysExt -using BlockArrays: Block -using BlockSparseArrays: BlockIndexVector, GenericBlockIndex -using KroneckerArrays: CartesianPair, CartesianProduct -function Base.getindex( - b::Block{N}, - I::Vararg{Union{CartesianPair, CartesianProduct}, N} - ) where {N} - return GenericBlockIndex(b, I) -end -function Base.getindex(b::Block{N}, I::Vararg{CartesianProduct, N}) where {N} - return BlockIndexVector(b, I) -end +using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerVector, + CartesianPair, CartesianProduct, CartesianProductUnitRange, + kroneckerfactors, ⊗, isactive, cartesianrange +using BlockArrays: BlockArrays, Block, AbstractBlockedUnitRange, mortar +using BlockSparseArrays: BlockSparseArrays, BlockIndexVector, GenericBlockIndex, ZeroBlocks, + blockrange, eachblockaxis, mortar_axis +using DiagonalArrays: ShapeInitializer -using BlockSparseArrays: BlockSparseArrays, blockrange -using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange -function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair}) - return blockrange(map(cartesianrange, bs)) -end -function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct}) - return blockrange(map(cartesianrange, bs)) -end -using BlockArrays: BlockArrays, mortar -using BlockSparseArrays: blockrange -using KroneckerArrays: CartesianProductUnitRange +Base.getindex(b::Block{N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}) where {N} = + GenericBlockIndex(b, I) +Base.getindex(b::Block{N}, I::Vararg{CartesianProduct, N}) where {N} = + BlockIndexVector(b, I) + +BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair}) = blockrange(map(cartesianrange, bs)) +BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct}) = blockrange(map(cartesianrange, bs)) + # Makes sure that `mortar` results in a `BlockVector` with the correct # axes, otherwise the axes would not preserve the Kronecker structure. # This is helpful when indexing `BlockUnitRange`, for example: # https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.7.1/src/blockaxis.jl#L540-L547 -function BlockArrays.mortar(blocks::AbstractVector{<:CartesianProductUnitRange}) - return mortar(blocks, (blockrange(map(Base.axes1, blocks)),)) -end +BlockArrays.mortar(blocks::AbstractVector{<:CartesianProductUnitRange}) = + mortar(blocks, (blockrange(map(Base.axes1, blocks)),)) -using BlockArrays: AbstractBlockedUnitRange -using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis -using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2, isactive -function KroneckerArrays.arg1(r::AbstractBlockedUnitRange) - return mortar_axis(arg1.(eachblockaxis(r))) -end -function KroneckerArrays.arg2(r::AbstractBlockedUnitRange) - return mortar_axis(arg2.(eachblockaxis(r))) -end +KroneckerArrays.kroneckerfactors(r::AbstractBlockedUnitRange, i::Int) = + mortar_axis(kroneckerfactors.(eachblockaxis(r), i)) +KroneckerArrays.kroneckerfactors(r::AbstractBlockedUnitRange) = + (kroneckerfactors(r, 1), kroneckerfactors(r, 2)) -function block_axes( - ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Vararg{Block{1}, N} - ) where {N} +function block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Vararg{Block{1}, N}) where {N} return ntuple(N) do d return only(axes(ax[d][I[d]])) end end -function block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Block{N}) where {N} - return block_axes(ax, Tuple(I)...) -end - -using DiagonalArrays: ShapeInitializer +block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Block{N}) where {N} = + block_axes(ax, Tuple(I)...) ## TODO: Is this needed? function Base.getindex( a::ZeroBlocks{N, KroneckerArray{T, N, A1, A2}}, I::Vararg{Int, N} ) where {T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}} - ax_a1 = map(arg1, a.parentaxes) - ax_a2 = map(arg2, a.parentaxes) - block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I))) - block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I))) + ax_a1 = kroneckerfactors.(a.parentaxes, 1) + ax_a2 = kroneckerfactors.(a.parentaxes, 2) + block_ax_a1 = kroneckerfactors.(block_axes(a.parentaxes, Block(I)), 1) + block_ax_a2 = kroneckerfactors.(block_axes(a.parentaxes, Block(I)), 2) # TODO: Is this a good definition? It is similar to # the definition of `similar` and `adapt_structure`. return if isactive(A1) == isactive(A2) @@ -76,10 +57,7 @@ function Base.getindex( end end -using BlockSparseArrays: BlockSparseArrays -using KroneckerArrays: KroneckerArrays, KroneckerVector -function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I) - return KroneckerArrays.to_truncated_indices(values, I) -end +BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I) = + KroneckerArrays.to_truncated_indices(values, I) end diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl index eb6587a..f997c8f 100644 --- a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -1,22 +1,22 @@ module KroneckerArraysTensorAlgebraExt -using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, arg1, arg2 -using TensorAlgebra: - TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize +using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, kroneckerfactors +using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, FusionStyle, + matricize, unmatricize struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle a::A b::B end -KroneckerArrays.arg1(style::KroneckerFusion) = style.a -KroneckerArrays.arg2(style::KroneckerFusion) = style.b -function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) - return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a))) -end +KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b) +KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B) + +TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) = KroneckerFusion(FusionStyle.(kroneckerfactors(a))...) function matricize_kronecker( style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} ) - return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm) + return matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), biperm) ⊗ + matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), biperm) end function TensorAlgebra.matricize( style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} @@ -32,8 +32,8 @@ function TensorAlgebra.matricize( return matricize_kronecker(style, a, biperm) end function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax) - return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗ - unmatricize(arg2(style), arg2(a), arg2.(ax)) + return unmatricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), kroneckerfactors.(ax, 1)) ⊗ + unmatricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), kroneckerfactors.(ax, 2)) end function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax) return unmatricize_kronecker(style, a, ax) diff --git a/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl b/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl index 4920d92..f227960 100644 --- a/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl +++ b/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl @@ -1,11 +1,14 @@ module KroneckerArraysTensorProductsExt -using KroneckerArrays: CartesianProductOneTo, ×, arg1, arg2, cartesianrange, unproduct using TensorProducts: TensorProducts, tensor_product +using KroneckerArrays: CartesianProductOneTo, kroneckerfactors, cartesianrange, unproduct + function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo) - prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2)) - range = tensor_product(unproduct(a1), unproduct(a2)) - return cartesianrange(prod, range) + return cartesianrange( + tensor_product(kroneckerfactors(a1, 1), kroneckerfactors(a2, 1)), + tensor_product(kroneckerfactors(a1, 2), kroneckerfactors(a2, 2)), + tensor_product(unproduct(a1), unproduct(a2)) + ) end end diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 0c74196..5698101 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -1,7 +1,52 @@ module KroneckerArrays +export kroneckerfactors, kroneckerfactortypes +export times, ×, cartesianproduct, cartesianrange, unproduct export ⊗, × +# Imports +# ------- +import Base.Broadcast as BC +using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag +using DiagonalArrays: DiagonalArrays +using DerivableInterfaces: DerivableInterfaces +using MapBroadcast: MapBroadcast, MapFunction, LinearCombination, Summed +using GPUArraysCore: GPUArraysCore +using Adapt: Adapt + +# Interfaces +# ---------- +@doc """ + kroneckerfactors(x) -> Tuple + kroneckerfactors(x, i) = kroneckerfactors(x)[i] + +Extract the factors of `x`, where `x` is an object that represents a lazily composed product type. +""" kroneckerfactors +# note: this is `Int` instead of `Integer` to avoid ambiguities downstream +@inline kroneckerfactors(x, i::Int) = kroneckerfactors(x)[i] + +@doc """ + kroneckerfactortypes(x) -> Tuple + kroneckerfactortypes(x, i) = kroneckerfactortypes(x)[i] + +Extract the types of the factors of `x`, where `x` is an object or type that represents a lazily composed product type. +""" kroneckerfactortypes +# note: this is `Int` instead of `Integer` to avoid ambiguities downstream +@inline kroneckerfactortypes(x, i::Int) = kroneckerfactortypes(x)[i] +kroneckerfactortypes(x) = kroneckerfactortypes(typeof(x)) +kroneckerfactortypes(T::Type) = throw(MethodError(kroneckerfactortypes, (T,))) + +@doc """ + ⊗(args...) + otimes(args...) + +Construct an object that represents the Kronecker product of the provided `args`. +""" otimes +function otimes(a, b) end +const ⊗ = otimes # unicode alternative + +# Includes +# -------- include("cartesianproduct.jl") include("kroneckerarray.jl") include("linearalgebra.jl") diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index 1e5dd79..ea1d2f0 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -1,196 +1,281 @@ -struct CartesianPair{A1, A2} - arg1::A1 - arg2::A2 -end -arguments(a::CartesianPair) = (arg1(a), arg2(a)) -arguments(a::CartesianPair, n::Int) = arguments(a)[n] - -arg1(a::CartesianPair) = getfield(a, :arg1) -arg2(a::CartesianPair) = getfield(a, :arg2) +# Cartesian product types +# ----------------------- +# This file contains several different definitions for cartesian product objects. +# The multiple types are required to get around Julia's type system not allowing parametric +# supertypes. -×(a1, a2) = CartesianPair(a1, a2) +""" + CartesianPair(a, b) -function Base.show(io::IO, a::CartesianPair) - print(io, arg1(a), " × ", arg2(a)) - return nothing +Represents a single element, the cartesian product of two arbitrary objects `a` and `b`. +""" +struct CartesianPair{A, B} + a::A + b::B end +""" + CartesianProduct(a::AbstractVector, b::AbstractVector) + +Represents the cartesian product of two collections `a` and `b`. +""" struct CartesianProduct{TA, TB, A <: AbstractVector{TA}, B <: AbstractVector{TB}} <: AbstractVector{CartesianPair{TA, TB}} a::A b::B end -arguments(a::CartesianProduct) = (arg1(a), arg2(a)) -arguments(a::CartesianProduct, n::Int) = arguments(a)[n] -arg1(a::CartesianProduct) = getfield(a, :a) -arg2(a::CartesianProduct) = getfield(a, :b) +""" + CartesianProductVector(a::AbstractVector, b::AbstractVector, values::AbstractVector{T}) <: AbstractVector{T} -Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a)) +Similar to the [`CartesianProduct`](@ref), this represents the cartesian product of two collections `a` and `b`. +However, as a vector it will behave as `values`, rather than `CartesianPair`s of the elements of `a` and `b`. +""" +struct CartesianProductVector{T, A, B, V <: AbstractVector{T}} <: AbstractVector{T} + a::A + b::B + values::V -function Base.show(io::IO, a::CartesianProduct) - print(io, arg1(a), " × ", arg2(a)) - return nothing -end -function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct) - show(io, a) - return nothing + function CartesianProductVector{T, A, B, V}( + a::A, b::B, values::V + ) where {T, A, B, V <: AbstractVector{T}} + length(a) * length(b) == length(values) || throw(DimensionMismatch()) + return new{T, A, B, V}(a, b, values) + end end +CartesianProductVector(a, b, values::AbstractVector{T}) where {T} = + CartesianProductVector{T, typeof(a), typeof(b), typeof(values)}(a, b, values) -×(a1::AbstractVector, a2::AbstractVector) = CartesianProduct(a1, a2) -Base.length(a::CartesianProduct) = length(arg1(a)) * length(arg2(a)) -Base.size(a::CartesianProduct) = (length(a),) +""" + CartesianProductUnitRange(a::AbstractUnitRange, b::AbstractUnitRange, range::AbstractUnitRange{T}) <: AbstractUnitRange{T} -function Base.getindex(a::CartesianProduct, i::CartesianProduct) - return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] -end -function Base.getindex(a::CartesianProduct, i::CartesianPair) - return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] -end -function Base.getindex(a::CartesianProduct, i::Int) - I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i]) - return a[I[2] × I[1]] -end +Similar to [`CartesianProductVector`](@ref), this represents the cartesian product of two ranges `a` and `b`. +However, as a range it will behave as `range`, rather than `CartesianPair`s of the elements of `a` and `b`. +""" +struct CartesianProductUnitRange{ + T, A <: AbstractUnitRange{T}, B <: AbstractUnitRange{T}, R <: AbstractUnitRange{T}, + } <: AbstractUnitRange{T} + a::A + b::B + range::R -struct CartesianProductVector{T, P <: CartesianProduct, V <: AbstractVector{T}} <: - AbstractVector{T} - product::P - values::V -end -cartesianproduct(r::CartesianProductVector) = getfield(r, :product) -unproduct(r::CartesianProductVector) = getfield(r, :values) -Base.length(a::CartesianProductVector) = length(unproduct(a)) -Base.size(a::CartesianProductVector) = (length(a),) -function Base.axes(r::CartesianProductVector) - prod = cartesianproduct(r) - prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) - return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) -end -function Base.copy(a::CartesianProductVector) - return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a))) -end -function Base.getindex(r::CartesianProductVector, i::Integer) - return unproduct(r)[i] + function CartesianProductUnitRange{T, A, B, R}( + a::A, b::B, range::R + ) where {T, A <: AbstractUnitRange{T}, B <: AbstractUnitRange{T}, R <: AbstractUnitRange{T}} + length(a) * length(b) == length(range) || throw(DimensionMismatch()) + return new{T, A, B, R}(a, b, range) + end end +CartesianProductUnitRange(a::AbstractUnitRange{T}, b::AbstractUnitRange{T}, range::AbstractUnitRange{T}) where {T} = + CartesianProductUnitRange{T, typeof(a), typeof(b), typeof(range)}(a, b, range) +CartesianProductUnitRange(a::AbstractUnitRange{T}, b::AbstractUnitRange{T}) where {T} = + CartesianProductUnitRange(a, b, Base.OneTo{T}(length(a) * length(b))) -function Base.show(io::IO, a::CartesianProductVector) - show(io, unproduct(a)) - return nothing -end -function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductVector) - show(io, mime, cartesianproduct(a)) - println(io) - show(io, mime, unproduct(a)) - return nothing -end +const CartesianProductOneTo{T, A <: AbstractUnitRange{T}, B <: AbstractUnitRange{T}} = + CartesianProductUnitRange{T, A, B, Base.OneTo{T}} -struct CartesianProductUnitRange{T, P <: CartesianProduct, R <: AbstractUnitRange{T}} <: - AbstractUnitRange{T} - product::P - range::R -end -Base.first(r::CartesianProductUnitRange) = first(r.range) -Base.last(r::CartesianProductUnitRange) = last(r.range) +const AnyCartesian = Union{CartesianPair, CartesianProduct, CartesianProductVector, CartesianProductUnitRange} -cartesianproduct(r::CartesianProductUnitRange) = getfield(r, :product) -unproduct(r::CartesianProductUnitRange) = getfield(r, :range) +# Utility constructors +# -------------------- +@doc """ + ×(args...) + times(args...) -arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a)) -arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a)) +Construct an object that represents the Cartesian product of the provided `args`. +By default this constructs the singular [`CartesianPair`](@ref) for unknown values, while attempting to promote to more structured types wherever possible. +See also [`CartesianProduct`](@ref), [`CartesianProductVector`](@ref) and [`CartesianProductUnitRange`](@ref). +""" times +# implement multi-argument version through a left fold +times(x) = x +times(x, y, z...) = foldl(times, (x, y, z...)) +const × = times # unicode alternative +# fallback definition for cartesian product +×(a, b) = CartesianPair(a, b) -function Base.getindex(a::CartesianProductUnitRange, i::CartesianProductUnitRange) - prod = cartesianproduct(a)[cartesianproduct(i)] - range = unproduct(a)[unproduct(i)] - return cartesianrange(prod, range) -end +# attempt to construct most specific type +×(a::AbstractVector, b::AbstractVector) = cartesianproduct(a, b) +×(a::AbstractUnitRange{T}, b::AbstractUnitRange{T}) where {T} = cartesianrange(a, b) -function Base.show(io::IO, a::CartesianProductUnitRange) - show(io, unproduct(a)) - return nothing -end -function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductUnitRange) - show(io, mime, cartesianproduct(a)) - println(io) - show(io, mime, unproduct(a)) - return nothing -end +@doc """ + cartesianproduct(a::AbstractVector, b::AbstractVector, [values::AbstractVector])::AbstractVector + +Construct an `AbstractVector` that represents the cartesian product `a × b`, but behaves as `values`. +This behaves similar to [`×`](@ref), but forces promotion to a `AbstractVector`. +""" cartesianproduct + +cartesianproduct(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b) +cartesianproduct(a::AbstractVector, b::AbstractVector, values::AbstractVector) = CartesianProductVector(a, b, values) +cartesianproduct(p::CartesianPair) = cartesianproduct(kroneckerfactors(p)...) +cartesianproduct(p::CartesianPair, values::AbstractVector) = cartesianproduct(kroneckerfactors(p)..., values) + +@doc """ + cartesianrange(a::AbstractUnitRange, b::AbstractUnitRange, [range::AbstractUnitRange])::AbstractUnitRange + +Construct a `UnitRange` that represents the cartesian product `a × b`, but behaves as `range`. +This behaves similar to [`×`](@ref), but forces promotion to a `AbstractUnitRange`. +""" cartesianrange -function CartesianProductUnitRange(p::CartesianProduct) - return CartesianProductUnitRange(p, Base.OneTo(length(p))) -end -function CartesianProductUnitRange(a1, a2) - return CartesianProductUnitRange(a1 × a2) -end to_product_indices(a::AbstractVector) = a to_product_indices(i::Integer) = Base.OneTo(i) -cartesianrange(a1, a2) = cartesianrange(to_product_indices(a1) × to_product_indices(a2)) -function cartesianrange(p::CartesianPair) - p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) - return cartesianrange(p′) + +cartesianrange(a, b) = CartesianProductUnitRange(to_product_indices(a), to_product_indices(b)) +cartesianrange(a, b, range::AbstractUnitRange) = CartesianProductUnitRange(to_product_indices(a), to_product_indices(b), range) +cartesianrange(p::Union{CartesianPair, CartesianProduct}) = cartesianrange(kroneckerfactors(p)...) +cartesianrange(p::Union{CartesianPair, CartesianProduct}, range::AbstractUnitRange) = cartesianrange(kroneckerfactors(p)..., range) + +# KroneckerArrays interface +# ------------------------- +kroneckerfactors(ab::AnyCartesian) = (ab.a, ab.b) +kroneckerfactortypes(::Type{T}) where {T <: CartesianPair} = fieldtypes(T) +kroneckerfactortypes(::Type{T}) where {T <: CartesianProduct} = kroneckerfactortypes(eltype(T)) +kroneckerfactortypes(::Type{<:CartesianProductVector{T, A, B}}) where {T, A, B} = (A, B) +kroneckerfactortypes(::Type{<:CartesianProductUnitRange{T, A, B}}) where {T, A, B} = (A, B) + +@doc """ + unproduct(a) + +For an object that holds a cartesian product of indices and their corresponding values, +this function removes the cartesian product layer and returns only the values. +""" unproduct + +unproduct(ab::CartesianProduct) = collect(ab) +unproduct(ab::CartesianProductVector) = ab.values +unproduct(ab::CartesianProductUnitRange) = ab.range + +# AbstractVector interface +# ------------------------ +Base.size(a::CartesianProduct) = (prod(length, kroneckerfactors(a)),) +Base.size(a::CartesianProductVector) = size(unproduct(a)) +Base.size(a::CartesianProductUnitRange) = size(unproduct(a)) + +# function Base.axes(r::CartesianProduct) +# prod_ax = only.(axes.(kroneckerfactors(r))) +# return (cartesianrange(prod_ax...),) +# end +function Base.axes(r::CartesianProductVector) + prod_ax = only.(axes.(kroneckerfactors(r))) + return (cartesianrange(prod_ax..., only(axes(r.values))),) +end +function Base.axes(r::CartesianProductUnitRange) + prod_ax = only.(axes.(kroneckerfactors(r))) + return (cartesianrange(prod_ax..., only(axes(r.range))),) end -function cartesianrange(p::CartesianProduct) - p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) - return cartesianrange(p′, Base.OneTo(length(p′))) +# TODO: add comment why this is here +Base.axes(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) +Base.axes1(S::Base.Slice{<:CartesianProductOneTo}) = S.indices +Base.unsafe_indices(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) + +Base.copy(a::CartesianProduct) = ×(copy.(kroneckerfactors(a)...)...) +Base.copy(a::CartesianProductVector) = cartesianproduct(copy.(kroneckerfactors(a))..., copy(unproduct(a))) + +@inline Base.getindex(a::CartesianProduct, i::CartesianProduct) = + ×(Base.getindex.(kroneckerfactors(a), kroneckerfactors(i))...) +@inline Base.getindex(a::CartesianProduct, i::CartesianPair) = + ×(Base.getindex.(kroneckerfactors(a), kroneckerfactors(i))...) + +Base.@propagate_inbounds function Base.getindex(a::CartesianProduct, i::Int) + I = Tuple(CartesianIndices(reverse(length.(kroneckerfactors(a))))[i]) + return a[I[2] × I[1]] end -function cartesianrange(p::CartesianPair, range::AbstractUnitRange) - p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) - return cartesianrange(p′, range) +@inline Base.getindex(r::CartesianProductVector, i::Int) = r.values[i] + +Base.@propagate_inbounds function Base.getindex(a::CartesianProductUnitRange, i::CartesianProductUnitRange) + return cartesianrange(Base.getindex.(kroneckerfactors(a), kroneckerfactors(i))..., a.range[i.range]) end -function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) - p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) - return CartesianProductUnitRange(p′, range) + +function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct) + return cartesianproduct(Base.getindex.(kroneckerfactors(a), kroneckerfactors(I))..., map(Base.Fix1(getindex, a), I)) end -function Base.axes(r::CartesianProductUnitRange) - prod = cartesianproduct(r) - prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) - return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) +# Reverse map from CartesianPair to linear index in the range. +Base.@propagate_inbounds function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair) + indsa, indsb = kroneckerfactors(inds) + ia, ib = kroneckerfactors(i) + i′ = CartesianIndex(findfirst(==(ib), indsb), findfirst(==(ia), indsa)) + i_linear = LinearIndices(reverse(length.(kroneckerfactors(inds))))[i′] + return inds[i_linear] end function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair) - return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) + indsa, indsb = kroneckerfactors(inds) + ia, ib = kroneckerfactors(i) + return checkindex(Bool, indsa, ia) && checkindex(Bool, indsb, ib) end -const CartesianProductOneTo{T, P <: CartesianProduct, R <: Base.OneTo{T}} = CartesianProductUnitRange{ - T, P, R, -} -Base.axes(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) -Base.axes1(S::Base.Slice{<:CartesianProductOneTo}) = S.indices -Base.unsafe_indices(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) -function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct) - prod = cartesianproduct(a) - prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)] - return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I)) +# AbstractUnitRange interface +# --------------------------- +Base.first(r::CartesianProductUnitRange) = first(r.range) +Base.last(r::CartesianProductUnitRange) = last(r.range) + + +# Broadcasting +# ------------ +for f in (:+, :-) + @eval BC.broadcasted(::BC.DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer) = + cartesianrange(kroneckerfactors(r)..., $f.(unproduct(r), x)) + @eval BC.broadcasted(::BC.DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange) = + cartesianrange(kroneckerfactors(r)..., $f.(x, unproduct(r))) end -# Reverse map from CartesianPair to linear index in the range. -function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair) - i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds))) - return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]] +function BC.axistype(r1::CartesianProductUnitRange, r2::CartesianProductUnitRange) + r1a, r1b = kroneckerfactors(r1) + r2a, r2b = kroneckerfactors(r2) + return cartesianrange(splat(BC.axistype).(((r1a, r2a), (r1b, r2b), (unproduct(r1), unproduct(r2))))...) end -using Base.Broadcast: DefaultArrayStyle -for f in (:+, :-) - @eval begin - function Broadcast.broadcasted( - ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer - ) - return CartesianProductUnitRange(cartesianproduct(r), $f.(unproduct(r), x)) - end - function Broadcast.broadcasted( - ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange - ) - return CartesianProductUnitRange(cartesianproduct(r), $f.(x, unproduct(r))) - end - end + +# Show +# ---- +function Base.show(io::IO, ab::Union{CartesianPair, CartesianProduct}) + a, b = kroneckerfactors(ab) + show(io, a) + print(io, " × ") + show(io, b) + return nothing +end +function Base.show(io::IO, mime::MIME"text/plain", ab::Union{CartesianPair, CartesianProduct}) + a, b = kroneckerfactors(ab) + compact = get(io, :compact, true)::Bool + show(io, mime, a) + compact || println(io) + print(io, " × ") + compact || println(io) + show(io, mime, b) + return nothing +end + +function Base.show(io::IO, ab::CartesianProductVector) + a, b = kroneckerfactors(ab) + print(io, "cartesianproduct(") + show(io, a) + print(io, ", ") + show(io, b) + print(io, ", ") + show(io, unproduct(ab)) + print(io, ")") + return nothing end -using Base.Broadcast: axistype -function Base.Broadcast.axistype( - r1::CartesianProductUnitRange, r2::CartesianProductUnitRange - ) - prod = axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2)) - range = axistype(unproduct(r1), unproduct(r2)) - return cartesianrange(prod, range) +function Base.show(io::IO, ab::CartesianProductUnitRange) + a, b = kroneckerfactors(ab) + range = unproduct(ab) + print(io, "cartesianrange(") + show(io, a) + print(io, ", ") + show(io, b) + print(io, ", ") + show(io, unproduct(ab)) + print(io, ")") + return nothing +end +function Base.show(io::IO, ab::CartesianProductOneTo) + a, b = kroneckerfactors(ab) + print(io, "(") + show(io, a) + print(io, " × ") + show(io, b) + print(io, ")") + return nothing end diff --git a/src/fillarrays.jl b/src/fillarrays.jl index 0d8fd79..a6e8b9c 100644 --- a/src/fillarrays.jl +++ b/src/fillarrays.jl @@ -5,7 +5,7 @@ function FillArrays.fillsimilar( CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, }, ) where {T} - return Zeros{T}(arg1.(ax)) ⊗ Zeros{T}(arg2.(ax)) + return Zeros{T}(kroneckerfactors.(ax, 1)) ⊗ Zeros{T}(kroneckerfactors.(ax, 2)) end # Simplification rules similar to those for FillArrays.jl: diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 3d5d5af..d52c4e7 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -9,27 +9,6 @@ abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end const AbstractKroneckerVector{T} = AbstractKroneckerArray{T, 1} const AbstractKroneckerMatrix{T} = AbstractKroneckerArray{T, 2} -@doc """ - arg1(AB::AbstractKroneckerArray{T, N}) - -Extract the first factor (`A`) of the Kronecker array `AB = A ⊗ B`. -""" arg1 - -@doc """ - arg2(AB::AbstractKroneckerArray{T, N}) - -Extract the second factor (`B`) of the Kronecker array `AB = A ⊗ B`. -""" arg2 - -arg1type(x::AbstractKroneckerArray) = arg1type(typeof(x)) -arg1type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg1type`.") -arg2type(x::AbstractKroneckerArray) = arg2type(typeof(x)) -arg2type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg2type`.") - -arguments(a::AbstractKroneckerArray) = (arg1(a), arg2(a)) -arguments(a::AbstractKroneckerArray, n::Int) = arguments(a)[n] -argument_types(a::AbstractKroneckerArray) = argument_types(typeof(a)) - function unwrap_array(a::AbstractArray) p = parent(a) p ≡ a && return a @@ -50,71 +29,64 @@ end function _convert(A::Type{<:AbstractArray}, a::AbstractArray) return convert(A, a) end -using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag _construct(A::Type{<:Diagonal}, a::AbstractMatrix) = A(diag(a)) function _convert(A::Type{<:Diagonal}, a::AbstractMatrix) LinearAlgebra.checksquare(a) return isdiag(a) ? _construct(A, a) : throw(InexactError(:convert, A, a)) end -struct KroneckerArray{T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}} <: +struct KroneckerArray{T, N, A <: AbstractArray{T, N}, B <: AbstractArray{T, N}} <: AbstractKroneckerArray{T, N} - arg1::A1 - arg2::A2 + a::A + b::B end -function KroneckerArray(a1::AbstractArray, a2::AbstractArray) - if ndims(a1) != ndims(a2) - throw( - ArgumentError("Kronecker product requires arrays of the same number of dimensions.") - ) - end - elt = promote_type(eltype(a1), eltype(a2)) - return _convert(AbstractArray{elt}, a1) ⊗ _convert(AbstractArray{elt}, a2) +function KroneckerArray(a::AbstractArray, b::AbstractArray) + ndims(a) == ndims(b) || + throw(DimensionMismatch("Kronecker product requires arrays of the same number of dimensions.")) + elt = promote_type(eltype(a), eltype(b)) + return _convert(AbstractArray{elt}, a) ⊗ _convert(AbstractArray{elt}, b) end -const KroneckerMatrix{T, A1 <: AbstractMatrix{T}, A2 <: AbstractMatrix{T}} = KroneckerArray{ - T, 2, A1, A2, -} -const KroneckerVector{T, A1 <: AbstractVector{T}, A2 <: AbstractVector{T}} = KroneckerArray{ - T, 1, A1, A2, -} -@inline arg1(a::KroneckerArray) = getfield(a, :arg1) -@inline arg2(a::KroneckerArray) = getfield(a, :arg2) -arg1type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A1 -arg2type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A2 +const KroneckerMatrix{T, A <: AbstractMatrix{T}, B <: AbstractMatrix{T}} = + KroneckerArray{T, 2, A, B} +const KroneckerVector{T, A <: AbstractVector{T}, B <: AbstractVector{T}} = + KroneckerArray{T, 1, A, B} + +kroneckerfactors(ab::KroneckerArray) = (ab.a, ab.b) +kroneckerfactortypes(::Type{KroneckerArray{T, N, A, B}}) where {T, N, A, B} = (A, B) -argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2) -function mutate_active_args!(f!, f, dest, src) - (isactive(arg1(dest)) || isactive(arg2(dest))) || +function mutate_active_args!(f!, f, C, A) + Ca, Cb = kroneckerfactors(C) + Aa, Ab = kroneckerfactors(A) + (isactive(Ca) || isactive(Cb)) || error("Can't mutate immutable KroneckerArray.") - if isactive(arg1(dest)) - f!(arg1(dest), arg1(src)) + if isactive(Ca) + f!(Ca, Aa) else - arg1(dest) == f(arg1(src)) || error("Immutable arguments aren't equal.") + Ca == f(Aa) || error("Immutable arguments aren't equal.") end - if isactive(arg2(dest)) - f!(arg2(dest), arg2(src)) + if isactive(Cb) + f!(Cb, Ab) else - arg2(dest) == f(arg2(src)) || error("Immutable arguments aren't equal.") + Cb == f(Ab) || error("Immutable arguments aren't equal.") end - return dest -end - -using Adapt: Adapt, adapt -function Adapt.adapt_structure(to, a::AbstractKroneckerArray) - # TODO: Is this a good definition? It is similar to - # the definition of `similar`. - return if isactive(arg1(a)) == isactive(arg2(a)) - adapt(to, arg1(a)) ⊗ adapt(to, arg2(a)) - elseif isactive(arg1(a)) - adapt(to, arg1(a)) ⊗ arg2(a) - elseif isactive(arg2(a)) - arg1(a) ⊗ adapt(to, arg2(a)) + return C +end + +function Adapt.adapt_structure(to, ab::AbstractKroneckerArray) + # TODO: Is this a good definition? It is similar to the definition of `similar`. + a, b = kroneckerfactors(ab) + return if isactive(a) == isactive(b) + Adapt.adapt(to, a) ⊗ Adapt.adapt(to, b) + elseif isactive(a) + Adapt.adapt(to, a) ⊗ b + elseif isactive(b) + a ⊗ Adapt.adapt(to, b) end end -Base.copy(a::AbstractKroneckerArray) = copy(arg1(a)) ⊗ copy(arg2(a)) +Base.copy(a::AbstractKroneckerArray) = ⊗(copy.(kroneckerfactors(a))...) function Base.copy!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray) return mutate_active_args!(copy!, copy, dest, src) end @@ -126,10 +98,11 @@ function Base.copyto!(dest::AbstractKroneckerArray{<:Any, N}, src::AbstractKrone end function Base.convert( - ::Type{KroneckerArray{T, N, A1, A2}}, a::AbstractKroneckerArray - )::KroneckerArray{T, N, A1, A2} where {T, N, A1, A2} - typeof(a) === KroneckerArray{T, N, A1, A2} && return a - return KroneckerArray(_convert(A1, arg1(a)), _convert(A2, arg2(a))) + ::Type{KroneckerArray{T, N, A, B}}, ab::AbstractKroneckerArray + )::KroneckerArray{T, N, A, B} where {T, N, A, B} + typeof(ab) === KroneckerArray{T, N, A, B} && return ab + a, b = kroneckerfactors(ab) + return KroneckerArray(_convert(A, a), _convert(B, b)) end # Promote the element type if needed. @@ -138,86 +111,76 @@ end maybe_promot_eltype(a, elt) = eltype(a) <: elt ? a : elt.(a) function Base.similar( - a::AbstractKroneckerArray, + ab::AbstractKroneckerArray, elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, - }, + axs::Tuple{CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}} ) # TODO: Is this a good definition? - return if isactive(arg1(a)) == isactive(arg2(a)) - similar(arg1(a), elt, arg1.(axs)) ⊗ similar(arg2(a), elt, arg2.(axs)) - elseif isactive(arg1(a)) - @assert arg2.(axs) == axes(arg2(a)) - similar(arg1(a), elt, arg1.(axs)) ⊗ maybe_promot_eltype(arg2(a), elt) - elseif isactive(arg2(a)) - @assert arg1.(axs) == axes(arg1(a)) - maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt, arg2.(axs)) + a, b = kroneckerfactors(ab) + return if isactive(a) == isactive(b) + similar(a, elt, kroneckerfactors.(axs, 1)) ⊗ similar(b, elt, kroneckerfactors.(axs, 2)) + elseif isactive(a) + @assert kroneckerfactors.(axs, 2) == axes(b) + similar(a, elt, kroneckerfactors.(axs, 1)) ⊗ maybe_promot_eltype(b, elt) + elseif isactive(b) + @assert kroneckerfactors.(axs, 1) == axes(a) + maybe_promot_eltype(a, elt) ⊗ similar(b, elt, kroneckerfactors.(axs, 2)) end end -function Base.similar(a::AbstractKroneckerArray, elt::Type) +function Base.similar(ab::AbstractKroneckerArray, elt::Type) # TODO: Is this a good definition? - return if isactive(arg1(a)) == isactive(arg2(a)) - similar(arg1(a), elt) ⊗ similar(arg2(a), elt) - elseif isactive(arg1(a)) - similar(arg1(a), elt) ⊗ maybe_promot_eltype(arg2(a), elt) - elseif isactive(arg2(a)) - maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt) + a, b = kroneckerfactors(ab) + return if isactive(a) == isactive(b) + similar(a, elt) ⊗ similar(b, elt) + elseif isactive(a) + similar(a, elt) ⊗ maybe_promot_eltype(b, elt) + elseif isactive(b) + maybe_promot_eltype(a, elt) ⊗ similar(b, elt) end end -function Base.similar(a::AbstractKroneckerArray) +function Base.similar(ab::AbstractKroneckerArray) # TODO: Is this a good definition? - return if isactive(arg1(a)) == isactive(arg2(a)) - similar(arg1(a)) ⊗ similar(arg2(a)) - elseif isactive(arg1(a)) - similar(arg1(a)) ⊗ arg2(a) - elseif isactive(arg2(a)) - arg1(a) ⊗ similar(arg2(a)) + a, b = kroneckerfactors(ab) + return if isactive(a) == isactive(b) + similar(a) ⊗ similar(b) + elseif isactive(a) + similar(a) ⊗ b + elseif isactive(b) + a ⊗ similar(b) end end function Base.similar( a::AbstractArray, elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, - }, + axs::Tuple{CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}} ) - return similar(a, elt, map(arg1, axs)) ⊗ similar(a, elt, map(arg2, axs)) + return similar(a, elt, kroneckerfactors.(axs, 1)) ⊗ similar(a, elt, kroneckerfactors.(axs, 2)) end function Base.similar( ::Type{ArrayT}, - axs::Tuple{ - CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, - }, + axs::Tuple{CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}} ) where {ArrayT <: AbstractKroneckerArray} - A1, A2 = arg1type(ArrayT), arg2type(ArrayT) - return similar(A1, map(arg1, axs)) ⊗ similar(A2, map(arg2, axs)) + A, B = kroneckerfactortypes(ArrayT) + return similar(A, kroneckerfactors.(axs, 1)) ⊗ similar(B, kroneckerfactors.(axs, 2)) end -function Base.similar( - ::Type{ArrayT}, sz::Tuple{Int, Vararg{Int}} - ) where {ArrayT <: AbstractKroneckerArray} - A1, A2 = arg1type(ArrayT), arg2type(ArrayT) - return similar(promote_type(A1, A2), sz) +function Base.similar(::Type{ArrayT}, sz::Tuple{Int, Vararg{Int}}) where {ArrayT <: AbstractKroneckerArray} + A, B = kroneckerfactortypes(ArrayT) + return similar(promote_type(A, B), sz) end function Base.similar( arrayt::Type{<:AbstractArray}, - axs::Tuple{ - CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, - }, + axs::Tuple{CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}} ) - return similar(arrayt, map(arg1, axs)) ⊗ similar(arrayt, map(arg2, axs)) + return similar(arrayt, kroneckerfactors.(axs, 1)) ⊗ similar(arrayt, kroneckerfactors.(axs, 2)) end -function Base.permutedims(a::AbstractKroneckerArray, perm) - return permutedims(arg1(a), perm) ⊗ permutedims(arg2(a), perm) -end -using DerivableInterfaces: DerivableInterfaces, permuteddims -function DerivableInterfaces.permuteddims(a::AbstractKroneckerArray, perm) - return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm) -end +Base.permutedims(ab::AbstractKroneckerArray, perm) = + ⊗(permutedims.(kroneckerfactors(ab), (perm,))...) +DerivableInterfaces.permuteddims(ab::AbstractKroneckerArray, perm) = + ⊗(DerivableInterfaces.permuteddims.(kroneckerfactors(ab), (perm,))...) function Base.permutedims!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray, perm) return mutate_active_args!( @@ -250,54 +213,51 @@ kron_nd(a1::AbstractMatrix, a2::AbstractMatrix) = kron(a1, a2) kron_nd(a1::AbstractVector, a2::AbstractVector) = kron(a1, a2) # Eagerly collect arguments to make more general on GPU. -Base.collect(a::AbstractKroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a))) -Base.collect(T::Type, a::AbstractKroneckerArray) = kron_nd(collect(T, arg1(a)), collect(T, arg2(a))) +Base.collect(ab::AbstractKroneckerArray) = kron_nd(collect.(kroneckerfactors(ab))...) +Base.collect(T::Type, ab::AbstractKroneckerArray) = kron_nd(collect.(T, kroneckerfactors(ab))...) -function Base.zero(a::AbstractKroneckerArray) - return if isactive(arg1(a)) == isactive(arg2(a)) +function Base.zero(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + return if isactive(a) == isactive(b) # TODO: Maybe this should zero both arguments? # This is how `a * false` would behave. - arg1(a) ⊗ zero(arg2(a)) - elseif isactive(arg1(a)) - zero(arg1(a)) ⊗ arg2(a) - elseif isactive(arg2(a)) - arg1(a) ⊗ zero(arg2(a)) + a ⊗ zero(b) + elseif isactive(a) + zero(a) ⊗ b + elseif isactive(b) + a ⊗ zero(b) end end -using DerivableInterfaces: DerivableInterfaces, zero! -function DerivableInterfaces.zero!(a::AbstractKroneckerArray) - (isactive(arg1(a)) || isactive(arg2(a))) || - error("Can't mutate immutable KroneckerArray.") - isactive(arg1(a)) && zero!(arg1(a)) - isactive(arg2(a)) && zero!(arg2(a)) - return a +function DerivableInterfaces.zero!(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + (isactive(a) || isactive(b)) || error("Can't mutate immutable KroneckerArray.") + isactive(a) && DerivableInterfaces.zero!(a) + isactive(b) && DerivableInterfaces.zero!(b) + return ab end -function Base.Array{T, N}(a::AbstractKroneckerArray{S, N}) where {T, S, N} - return convert(Array{T, N}, collect(a)) -end +Base.Array{T, N}(a::AbstractKroneckerArray) where {T, N} = convert(Array{T, N}, collect(a)) -Base.size(a::AbstractKroneckerArray) = size(arg1(a)) .* size(arg2(a)) +Base.size(ab::AbstractKroneckerArray) = broadcast(*, size.(kroneckerfactors(ab))...) -function Base.axes(a::AbstractKroneckerArray) - return ntuple(ndims(a)) do dim - return CartesianProductUnitRange( - axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim)) - ) - end +function Base.axes(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + return axes(a) .× axes(b) end -function Base.print_array(io::IO, a::KroneckerArray) - Base.print_array(io, arg1(a)) +function Base.print_array(io::IO, ab::KroneckerArray) + a, b = kroneckerfactors(ab) + Base.print_array(io, a) println(io, "\n ⊗") - Base.print_array(io, arg2(a)) + Base.print_array(io, b) return nothing end -function Base.show(io::IO, a::KroneckerArray) - show(io, arg1(a)) +function Base.show(io::IO, ab::KroneckerArray) + a, b = kroneckerfactors(ab) + show(io, a) print(io, " ⊗ ") - show(io, arg2(a)) + show(io, b) return nothing end @@ -306,73 +266,78 @@ end ⊗(a1::Number, a2::AbstractArray) = a1 * a2 ⊗(a1::AbstractArray, a2::Number) = a1 * a2 -function Base.getindex(a::KroneckerArray, i::Integer) - return a[CartesianIndices(a)[i]] -end - -using GPUArraysCore: GPUArraysCore -function Base.getindex(a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Integer, N}) where {N} +function Base.getindex(a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Int, N}) where {N} GPUArraysCore.assertscalar("getindex") I′ = ntuple(Val(N)) do dim - return cartesianproduct(axes(a, dim))[I[dim]] + return cartesianproduct(kroneckerfactors(axes(a, dim))...)[I[dim]] end return a[I′...] end # Indexing logic. function Base.to_indices( - a::AbstractKroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg} + ab::AbstractKroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg} ) - I1 = to_indices(arg1(a), arg1.(inds), arg1.(I)) - I2 = to_indices(arg2(a), arg2.(inds), arg2.(I)) + a, b = kroneckerfactors(ab) + I1 = to_indices(a, kroneckerfactors.(inds, 1), kroneckerfactors.(I, 1)) + I2 = to_indices(b, kroneckerfactors.(inds, 2), kroneckerfactors.(I, 2)) return I1 .× I2 end function Base.getindex( - a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N} + ab::AbstractKroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct, CartesianProductUnitRange}, N} ) where {N} - I′ = to_indices(a, I) - return arg1(a)[arg1.(I′)...] ⊗ arg2(a)[arg2.(I′)...] + I′ = to_indices(ab, I) + a, b = kroneckerfactors(ab) + return a[kroneckerfactors.(I′, 1)...] ⊗ b[kroneckerfactors.(I′, 2)...] end + # Fix ambigiuity error. -Base.getindex(a::AbstractKroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[] +Base.getindex(ab::AbstractKroneckerArray{<:Any, 0}) = *(getindex.(kroneckerfactors(ab))...) + +kroneckerfactors(::Colon) = ((:), (:)) +kroneckerfactors(::Base.Slice) = ((:), (:)) -arg1(::Colon) = (:) -arg2(::Colon) = (:) -arg1(::Base.Slice) = (:) -arg2(::Base.Slice) = (:) function Base.view( - a::AbstractKroneckerArray{<:Any, N}, + ab::AbstractKroneckerArray{<:Any, N}, I::Vararg{Union{CartesianProduct, CartesianProductUnitRange, Base.Slice, Colon}, N}, ) where {N} - return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...) + a, b = kroneckerfactors(ab) + Ia = kroneckerfactors.(I, 1) + Ib = kroneckerfactors.(I, 2) + return view(a, Ia...) ⊗ view(b, Ib...) end -function Base.view(a::AbstractKroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N} - return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...) +function Base.view(ab::AbstractKroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N} + a, b = kroneckerfactors(ab) + Ia = kroneckerfactors.(I, 1) + Ib = kroneckerfactors.(I, 2) + return view(a, Ia...) ⊗ view(b, Ib...) end # Fix ambigiuity error. -Base.view(a::AbstractKroneckerArray{<:Any, 0}) = view(arg1(a)) ⊗ view(arg2(a)) +Base.view(ab::AbstractKroneckerArray{<:Any, 0}) = ⊗(view.(kroneckerfactors(ab))...) -function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray) - return arg1(a) == arg1(b) && arg2(a) == arg2(b) +function Base.:(==)(ab::AbstractKroneckerArray, cd::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + c, d = kroneckerfactors(cd) + return a == c && b == d end # norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2) # = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2)) function dist_kronecker(a::AbstractKroneckerArray, b::AbstractKroneckerArray) - a1, a2 = arg1(a), arg2(a) - b1, b2 = arg1(b), arg2(b) + a1, a2 = kroneckerfactors(a) + b1, b2 = kroneckerfactors(b) diff1 = a1 - b1 diff2 = a2 - b2 # x = (a1 - b1) ⊗ a2 # y = b1 ⊗ (a2 - b2) # z = (a1 - b1) ⊗ (a2 - b2) - xx = norm(diff1)^2 * norm(a2)^2 - yy = norm(b1)^2 * norm(diff2)^2 - zz = norm(diff1)^2 * norm(diff2)^2 - xy = real(dot(diff1, b1) * dot(a2, diff2)) - xz = real(dot(diff1, diff1) * dot(a2, diff2)) - yz = real(dot(b1, diff1) * dot(diff2, diff2)) + xx = LinearAlgebra.norm(diff1)^2 * LinearAlgebra.norm(a2)^2 + yy = LinearAlgebra.norm(b1)^2 * LinearAlgebra.norm(diff2)^2 + zz = LinearAlgebra.norm(diff1)^2 * LinearAlgebra.norm(diff2)^2 + xy = real(LinearAlgebra.dot(diff1, b1) * LinearAlgebra.dot(a2, diff2)) + xz = real(LinearAlgebra.dot(diff1, diff1) * LinearAlgebra.dot(a2, diff2)) + yz = real(LinearAlgebra.dot(b1, diff1) * LinearAlgebra.dot(diff2, diff2)) # `abs` is used in case there are negative values due to floating point roundoff errors. return sqrt(abs(xx + yy + zz + 2 * (xy + xz + yz))) end @@ -382,12 +347,12 @@ function Base.isapprox( a::AbstractKroneckerArray, b::AbstractKroneckerArray; atol::Real = 0, rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol), ) - a1, a2 = arg1(a), arg2(a) - b1, b2 = arg1(b), arg2(b) + a1, a2 = kroneckerfactors(a) + b1, b2 = kroneckerfactors(b) if a1 == b1 - return isapprox(a2, b2; atol = atol / norm(a1), rtol) + return isapprox(a2, b2; atol = atol / LinearAlgebra.norm(a1), rtol) elseif a2 == b2 - return isapprox(a1, b1; atol = atol / norm(a2), rtol) + return isapprox(a1, b1; atol = atol / LinearAlgebra.norm(a2), rtol) else # This could be defined as: # ```julia @@ -404,82 +369,81 @@ function Base.isapprox( end end -function Base.iszero(a::AbstractKroneckerArray) - return iszero(arg1(a)) || iszero(arg2(a)) +function Base.iszero(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + return iszero(a) || iszero(b) end -function Base.isreal(a::KroneckerArray) - return isreal(arg1(a)) && isreal(arg2(a)) +function Base.isreal(ab::KroneckerArray) + a, b = kroneckerfactors(ab) + return isreal(a) && isreal(b) end -using DiagonalArrays: DiagonalArrays, diagonal -function DiagonalArrays.diagonal(a::KroneckerArray) - return diagonal(arg1(a)) ⊗ diagonal(arg2(a)) -end +DiagonalArrays.diagonal(ab::KroneckerArray) = ⊗(DiagonalArrays.diagonal.(kroneckerfactors(ab))...) -Base.real(a::AbstractKroneckerArray{<:Real}) = a -function Base.real(a::AbstractKroneckerArray) - if iszero(imag(arg1(a))) || iszero(imag(arg2(a))) - return real(arg1(a)) ⊗ real(arg2(a)) - elseif iszero(real(arg1(a))) || iszero(real(arg2(a))) - return -(imag(arg1(a)) ⊗ imag(arg2(a))) +Base.real(ab::AbstractKroneckerArray{<:Real}) = ab +# TODO: the extra checks here are probably as expensive as the general case +function Base.real(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + if iszero(imag(a)) || iszero(imag(b)) + return real(a) ⊗ real(b) + elseif iszero(real(a)) || iszero(real(b)) + return -(imag(a) ⊗ imag(b)) end - return real(arg1(a)) ⊗ real(arg2(a)) - imag(arg1(a)) ⊗ imag(arg2(a)) -end -Base.imag(a::AbstractKroneckerArray{<:Real}) = zero(a) -function Base.imag(a::AbstractKroneckerArray) - if iszero(imag(arg1(a))) || iszero(real(arg2(a))) - return real(arg1(a)) ⊗ imag(arg2(a)) - elseif iszero(real(arg1(a))) || iszero(imag(arg2(a))) - return imag(arg1(a)) ⊗ real(arg2(a)) - end - return real(arg1(a)) ⊗ imag(arg2(a)) + imag(arg1(a)) ⊗ real(arg2(a)) + return real(a) ⊗ real(b) - imag(a) ⊗ imag(b) end -for f in [:transpose, :adjoint, :inv] - @eval begin - function Base.$f(a::AbstractKroneckerArray) - return $f(arg1(a)) ⊗ $f(arg2(a)) - end +Base.imag(ab::AbstractKroneckerArray{<:Real}) = zero(ab) +# TODO: the extra checks here are probably as expensive as the general case +function Base.imag(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + if iszero(imag(a)) || iszero(real(b)) + return real(a) ⊗ imag(b) + elseif iszero(real(a)) || iszero(imag(b)) + return imag(a) ⊗ real(b) end + return real(a) ⊗ imag(b) + imag(a) ⊗ real(b) +end + +for f in (:transpose, :adjoint, :inv) + @eval Base.$f(ab::AbstractKroneckerArray) = ⊗($f.(kroneckerfactors(ab))...) end function Base.reshape( - a::AbstractKroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}} + ab::AbstractKroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}} ) - return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax)) + a, b = kroneckerfactors(ab) + return reshape(a, kroneckerfactors.(ax, 1)) ⊗ reshape(b, kroneckerfactors.(ax, 2)) end using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted -struct KroneckerStyle{N, A1, A2} <: AbstractArrayStyle{N} end -arg1(::Type{<:KroneckerStyle{<:Any, A1}}) where {A1} = A1 -arg1(style::KroneckerStyle) = arg1(typeof(style)) -arg2(::Type{<:KroneckerStyle{<:Any, <:Any, A2}}) where {A2} = A2 -arg2(style::KroneckerStyle) = arg2(typeof(style)) -function KroneckerStyle{N}(a1::BroadcastStyle, a2::BroadcastStyle) where {N} - return KroneckerStyle{N, a1, a2}() -end -function KroneckerStyle(a1::AbstractArrayStyle{N}, a2::AbstractArrayStyle{N}) where {N} - return KroneckerStyle{N}(a1, a2) -end -function KroneckerStyle{N, A1, A2}(v::Val{M}) where {N, A1, A2, M} - return KroneckerStyle{M, typeof(A1)(v), typeof(A2)(v)}() -end -function Base.BroadcastStyle(::Type{T}) where {T <: AbstractKroneckerArray} - return KroneckerStyle{ndims(T)}(BroadcastStyle(arg1type(T)), BroadcastStyle(arg2type(T))) -end + +struct KroneckerStyle{N, A, B} <: BC.AbstractArrayStyle{N} end + +kroneckerfactors(::Type{KroneckerStyle{N, A, B}}) where {N, A, B} = (A, B) +kroneckerfactors(style::KroneckerStyle) = kroneckerfactors(typeof(style)) + +KroneckerStyle{N}(A::BroadcastStyle, B::BroadcastStyle) where {N} = KroneckerStyle{N, A, B}() +KroneckerStyle(A::AbstractArrayStyle{N}, B::AbstractArrayStyle{N}) where {N} = KroneckerStyle{N}(A, B) +KroneckerStyle{N, A, B}(v::Val{M}) where {N, A, B, M} = KroneckerStyle{M, typeof(A)(v), typeof(B)(v)}() + +Base.BroadcastStyle(::Type{T}) where {T <: AbstractKroneckerArray} = + KroneckerStyle{ndims(T)}(BroadcastStyle.(kroneckerfactortypes(T))...) function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N} - style_a = BroadcastStyle(arg1(style1), arg1(style2)) - (style_a isa Broadcast.Unknown) && return Broadcast.Unknown() - style_b = BroadcastStyle(arg2(style1), arg2(style2)) - (style_b isa Broadcast.Unknown) && return Broadcast.Unknown() + A1, B1 = kroneckerfactors(style1) + A2, B2 = kroneckerfactors(style2) + style_a = BroadcastStyle(A1, A2) + (style_a isa BC.Unknown) && return BC.Unknown() + style_b = BroadcastStyle(B1, B2) + (style_b isa BC.Unknown) && return BC.Unknown() return KroneckerStyle{N}(style_a, style_b) end + function Base.similar( - bc::Broadcasted{<:KroneckerStyle{N, A1, A2}}, elt::Type, ax - ) where {N, A1, A2} - bc_a = Broadcasted(A1, bc.f, arg1.(bc.args), arg1.(ax)) - bc_b = Broadcasted(A2, bc.f, arg2.(bc.args), arg2.(ax)) + bc::BC.Broadcasted{<:KroneckerStyle{N, A, B}}, elt::Type, ax + ) where {N, A, B} + bc_a = BC.Broadcasted(A, bc.f, kroneckerfactors.(bc.args, 1), kroneckerfactors.(ax, 1)) a = similar(bc_a, elt) + bc_b = BC.Broadcasted(B, bc.f, kroneckerfactors.(bc.args, 1), kroneckerfactors.(ax, 2)) b = similar(bc_b, elt) return a ⊗ b end @@ -492,16 +456,14 @@ function Base.map!(f, dest::AbstractKroneckerArray, a1::AbstractKroneckerArray, return dest end -using MapBroadcast: MapBroadcast, LinearCombination, Summed function KroneckerBroadcast(a::Summed{<:KroneckerStyle}) f = LinearCombination(a) args = MapBroadcast.arguments(a) - arg1s = arg1.(args) - arg2s = arg2.(args) + arg1s = kroneckerfactors.(args, 1) + arg2s = kroneckerfactors.(args, 2) arg1_isunique = allequal(arg1s) arg2_isunique = allequal(arg2s) - (arg1_isunique || arg2_isunique) || - error("This operation doesn't preserve the Kronecker structure.") + (arg1_isunique || arg2_isunique) || error("This operation doesn't preserve the Kronecker structure.") broadcast_arg = if arg1_isunique && arg2_isunique isactive(first(arg1s)) ? 1 : 2 elseif arg1_isunique @@ -510,125 +472,89 @@ function KroneckerBroadcast(a::Summed{<:KroneckerStyle}) 1 end return if broadcast_arg == 1 - broadcasted(f, arg1s...) ⊗ first(arg2s) + BC.broadcasted(f, arg1s...) ⊗ first(arg2s) elseif broadcast_arg == 2 - first(arg1s) ⊗ broadcasted(f, arg2s...) + first(arg1s) ⊗ BC.broadcasted(f, arg2s...) end end -function Base.copy(a::Summed{<:KroneckerStyle}) - return copy(KroneckerBroadcast(a)) -end -function Base.copyto!(dest::AbstractKroneckerArray, a::Summed{<:KroneckerStyle}) - return copyto!(dest, KroneckerBroadcast(a)) -end +Base.copy(a::Summed{<:KroneckerStyle}) = copy(KroneckerBroadcast(a)) +Base.copyto!(dest::AbstractKroneckerArray, a::Summed{<:KroneckerStyle}) = copyto!(dest, KroneckerBroadcast(a)) -function Broadcast.broadcasted(::KroneckerStyle, f, as...) - return error("Arbitrary broadcasting not supported for KroneckerArray.") -end +BC.broadcasted(::KroneckerStyle, f, as...) = error("Arbitrary broadcasting not supported for KroneckerStyle.") # Linear operations. -function Broadcast.broadcasted(::KroneckerStyle, ::typeof(+), a1, a2) - return Summed(a1) + Summed(a2) -end -function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a1, a2) - return Summed(a1) - Summed(a2) -end -function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), c::Number, a) - return c * Summed(a) -end -function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a, c::Number) - return Summed(a) * c -end +BC.broadcasted(::KroneckerStyle, ::typeof(+), a1, a2) = Summed(a1) + Summed(a2) +BC.broadcasted(::KroneckerStyle, ::typeof(-), a1, a2) = Summed(a1) - Summed(a2) +BC.broadcasted(::KroneckerStyle, ::typeof(*), c::Number, a) = c * Summed(a) +BC.broadcasted(::KroneckerStyle, ::typeof(*), a, c::Number) = Summed(a) * c + # Fix ambiguity error. -function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a::Number, b::Number) - return a * b -end -function Broadcast.broadcasted(::KroneckerStyle, ::typeof(/), a, c::Number) - return Summed(a) / c -end -function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a) - return -Summed(a) -end +BC.broadcasted(::KroneckerStyle, ::typeof(*), a::Number, b::Number) = a * b +BC.broadcasted(::KroneckerStyle, ::typeof(/), a, c::Number) = Summed(a) / c +BC.broadcasted(::KroneckerStyle, ::typeof(-), a) = -Summed(a) # Rewrite rules to canonicalize broadcast expressions. -function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix1{typeof(*), <:Number}, a) - return broadcasted(style, *, f.x, a) -end -function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(*), <:Number}, a) - return broadcasted(style, *, a, f.x) -end -function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(/), <:Number}, a) - return broadcasted(style, /, a, f.x) -end +BC.broadcasted(style::KroneckerStyle, f::Base.Fix1{typeof(*), <:Number}, a) = BC.broadcasted(style, *, f.x, a) +BC.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(*), <:Number}, a) = BC.broadcasted(style, *, a, f.x) +BC.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(/), <:Number}, a) = BC.broadcasted(style, /, a, f.x) # Compatibility with MapBroadcast.jl. -using MapBroadcast: MapBroadcast, MapFunction -function Base.broadcasted( - style::KroneckerStyle, f::MapFunction{typeof(*), <:Tuple{<:Number, MapBroadcast.Arg}}, a - ) - return broadcasted(style, *, f.args[1], a) -end -function Base.broadcasted( - style::KroneckerStyle, f::MapFunction{typeof(*), <:Tuple{MapBroadcast.Arg, <:Number}}, a - ) - return broadcasted(style, *, a, f.args[2]) -end -function Base.broadcasted( - style::KroneckerStyle, f::MapFunction{typeof(/), <:Tuple{MapBroadcast.Arg, <:Number}}, a - ) - return broadcasted(style, /, a, f.args[2]) -end +BC.broadcasted(style::KroneckerStyle, f::MapFunction{typeof(*), <:Tuple{<:Number, MapBroadcast.Arg}}, a) = + BC.broadcasted(style, *, f.args[1], a) +BC.broadcasted(style::KroneckerStyle, f::MapFunction{typeof(*), <:Tuple{MapBroadcast.Arg, <:Number}}, a) = + BC.broadcasted(style, *, a, f.args[2]) +BC.broadcasted(style::KroneckerStyle, f::MapFunction{typeof(/), <:Tuple{MapBroadcast.Arg, <:Number}}, a) = + BC.broadcasted(style, /, a, f.args[2]) + # Use to determine the element type of KroneckerBroadcasted. _eltype(x) = eltype(x) -_eltype(x::Broadcasted) = Base.promote_op(x.f, _eltype.(x.args)...) +_eltype(x::BC.Broadcasted) = Base.promote_op(x.f, _eltype.(x.args)...) -using Base.Broadcast: broadcasted # Represents broadcast operations that can be applied Kronecker-wise, # i.e. independently to each argument of the Kronecker product. # Note that not all broadcast operations can be mapped to this. -struct KroneckerBroadcasted{A1, A2} - arg1::A1 - arg2::A2 -end -@inline arg1(a::KroneckerBroadcasted) = getfield(a, :arg1) -@inline arg2(a::KroneckerBroadcasted) = getfield(a, :arg2) -⊗(a1::Broadcasted, a2::Broadcasted) = KroneckerBroadcasted(a1, a2) -⊗(a1::Broadcasted, a2) = KroneckerBroadcasted(a1, a2) -⊗(a1, a2::Broadcasted) = KroneckerBroadcasted(a1, a2) -Broadcast.materialize(a::KroneckerBroadcasted) = copy(a) -Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a) -Broadcast.broadcastable(a::KroneckerBroadcasted) = a -Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) ⊗ copy(arg2(a)) +struct KroneckerBroadcasted{A, B} + a::A + b::B +end + +kroneckerfactors(ab::KroneckerBroadcasted) = ab.a, ab.b +kroneckerfactortypes(::Type{KroneckerBroadcasted{A, B}}) where {A, B} = (A, B) + +⊗(a1::BC.Broadcasted, a2::BC.Broadcasted) = KroneckerBroadcasted(a1, a2) +⊗(a1::BC.Broadcasted, a2) = KroneckerBroadcasted(a1, a2) +⊗(a1, a2::BC.Broadcasted) = KroneckerBroadcasted(a1, a2) + +BC.materialize(a::KroneckerBroadcasted) = copy(a) +BC.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a) +BC.broadcastable(a::KroneckerBroadcasted) = a + +Base.copy(ab::KroneckerBroadcasted) = ⊗(copy.(kroneckerfactors(ab))...) function Base.copyto!(dest::AbstractKroneckerArray, src::KroneckerBroadcasted) return mutate_active_args!(copyto!, copy, dest, src) end -function Base.eltype(a::KroneckerBroadcasted) - a1 = arg1(a) - a2 = arg2(a) - return Base.promote_op(*, _eltype(a1), _eltype(a2)) +function Base.eltype(ab::KroneckerBroadcasted) + a, b = kroneckerfactors(ab) + return Base.promote_op(*, _eltype(a), _eltype(b)) end -function Base.axes(a::KroneckerBroadcasted) - ax1 = axes(arg1(a)) - ax2 = axes(arg2(a)) - return cartesianrange.(ax1 .× ax2) +function Base.axes(ab::KroneckerBroadcasted) + ax1, ax2 = axes.(kroneckerfactors(ab)) + return cartesianrange.(ax1, ax2) end function Base.BroadcastStyle( - ::Type{<:KroneckerBroadcasted{A1, A2}} - ) where {StyleA1, StyleA2, A1 <: Broadcasted{StyleA1}, A2 <: Broadcasted{StyleA2}} - @assert ndims(A1) == ndims(A2) - N = ndims(A1) + ::Type{<:KroneckerBroadcasted{A, B}} + ) where {StyleA1, StyleA2, A <: BC.Broadcasted{StyleA1}, B <: BC.Broadcasted{StyleA2}} + @assert ndims(A) == ndims(B) + N = ndims(A) return KroneckerStyle{N}(StyleA1(), StyleA2()) end # Operations that preserve the Kronecker structure. -for f in [:identity, :conj] - @eval begin - function Broadcast.broadcasted( - ::KroneckerStyle{<:Any, A1, A2}, ::typeof($f), a - ) where {A1, A2} - return broadcasted(A1, $f, arg1(a)) ⊗ broadcasted(A2, $f, arg2(a)) - end +for f in (:identity, :conj) + @eval function BC.broadcasted(::KroneckerStyle{<:Any, A, B}, ::typeof($f), ab) where {A, B} + a, b = kroneckerfactors(ab) + return BC.broadcasted(A, $f, a) ⊗ BC.broadcasted(B, $f, b) end end diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index c4dd865..9cf8b33 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -1,66 +1,36 @@ -using DiagonalArrays: δ -using LinearAlgebra: - LinearAlgebra, - Diagonal, - Eigen, - SVD, - det, - diag, - eigen, - eigvals, - lq, - mul!, - norm, - qr, - svd, - svdvals, - tr - -using LinearAlgebra: LinearAlgebra -function KroneckerArray(J::LinearAlgebra.UniformScaling, ax::Tuple) - return δ(eltype(J), arg1.(ax)) ⊗ δ(eltype(J), arg2.(ax)) -end +KroneckerArray(J::LinearAlgebra.UniformScaling, ax::Tuple) = + DiagonalArrays.δ(eltype(J), kroneckerfactors.(ax, 1)) ⊗ DiagonalArrays.δ(eltype(J), kroneckerfactors.(ax, 2)) + function Base.copyto!(a::KroneckerArray, J::LinearAlgebra.UniformScaling) copyto!(a, KroneckerArray(J, axes(a))) return a end -using LinearAlgebra: LinearAlgebra, pinv -function LinearAlgebra.pinv(a::KroneckerArray; kwargs...) - return pinv(arg1(a); kwargs...) ⊗ pinv(arg2(a); kwargs...) -end - -function LinearAlgebra.diag(a::AbstractKroneckerArray) - return copy(DiagonalArrays.diagview(a)) -end +LinearAlgebra.pinv(a::KroneckerArray; kwargs...) = ⊗(LinearAlgebra.pinv.(kroneckerfactors(a); kwargs...)...) +LinearAlgebra.diag(a::AbstractKroneckerArray) = copy(DiagonalArrays.diagview(a)) +LinearAlgebra.tr(a::AbstractKroneckerArray) = *(LinearAlgebra.tr.(kroneckerfactors(a))...) +LinearAlgebra.norm(a::AbstractKroneckerArray, p::Real = 2) = *(LinearAlgebra.norm.(kroneckerfactors(a), p)...) -function Base.:*(a::AbstractKroneckerArray, b::AbstractKroneckerArray) - return (arg1(a) * arg1(b)) ⊗ (arg2(a) * arg2(b)) +function Base.:*(A::AbstractKroneckerArray, B::AbstractKroneckerArray) + a, b = kroneckerfactors(A) + c, d = kroneckerfactors(B) + return (a * c) ⊗ (b * d) end function LinearAlgebra.mul!( c::AbstractKroneckerArray, a::AbstractKroneckerArray, b::AbstractKroneckerArray, α::Number, β::Number ) - iszero(β) || iszero(c) || throw( - ArgumentError( - "Can't multiply KroneckerArrays with nonzero β and nonzero destination." - ), - ) + iszero(β) || iszero(c) || + throw(ArgumentError("Can't multiply KroneckerArrays with nonzero β and nonzero destination.")) # TODO: Only perform in-place operation on the non-active argument(s). - mul!(arg1(c), arg1(a), arg1(b)) - mul!(arg2(c), arg2(a), arg2(b), α, β) + ca, cb = kroneckerfactors(c) + aa, ab = kroneckerfactors(a) + ba, bb = kroneckerfactors(b) + LinearAlgebra.mul!(ca, aa, ba) + LinearAlgebra.mul!(cb, ab, bb, α, β) return c end -using LinearAlgebra: tr -function LinearAlgebra.tr(a::AbstractKroneckerArray) - return tr(arg1(a)) * tr(arg2(a)) -end - -using LinearAlgebra: norm -function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Real = 2) - return norm(arg1(a), p) * norm(arg2(a), p) -end # Matrix functions const MATRIX_FUNCTIONS = [ @@ -96,15 +66,14 @@ const MATRIX_FUNCTIONS = [ ] for f in MATRIX_FUNCTIONS - @eval begin - function Base.$f(a::AbstractKroneckerArray) - return if isone(arg1(a)) - arg1(a) ⊗ $f(arg2(a)) - elseif isone(arg2(a)) - $f(arg1(a)) ⊗ arg2(a) - else - throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported.")) - end + @eval function Base.$f(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + return if isone(a) + a ⊗ $f(b) + elseif isone(b) + $f(a) ⊗ b + else + throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported.")) end end end @@ -113,51 +82,39 @@ end # than `LinearAlgebra.checksquare`, for example it compares axes and can check # that the codomain and domain are dual of each other. using DiagonalArrays: DiagonalArrays, checksquare, issquare -function DiagonalArrays.issquare(a::AbstractKroneckerArray) - return issquare(arg1(a)) && issquare(arg2(a)) +function DiagonalArrays.issquare(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + return DiagonalArrays.issquare(a) && DiagonalArrays.issquare(b) end -using LinearAlgebra: det -function LinearAlgebra.det(a::AbstractKroneckerArray) - checksquare(a) - return det(arg1(a))^size(arg2(a), 1) * det(arg2(a))^size(arg1(a), 1) +function LinearAlgebra.det(ab::AbstractKroneckerArray) + a, b = kroneckerfactors(ab) + return LinearAlgebra.det(a)^size(b, 1) * LinearAlgebra.det(b)^size(a, 1) end -function LinearAlgebra.svd(a::AbstractKroneckerArray) - F1 = svd(arg1(a)) - F2 = svd(arg2(a)) - return SVD(F1.U ⊗ F2.U, F1.S ⊗ F2.S, F1.Vt ⊗ F2.Vt) -end -function LinearAlgebra.svdvals(a::AbstractKroneckerArray) - return svdvals(arg1(a)) ⊗ svdvals(arg2(a)) -end -function LinearAlgebra.eigen(a::AbstractKroneckerArray) - F1 = eigen(arg1(a)) - F2 = eigen(arg2(a)) - return Eigen(F1.values ⊗ F2.values, F1.vectors ⊗ F2.vectors) -end -function LinearAlgebra.eigvals(a::AbstractKroneckerArray) - return eigvals(arg1(a)) ⊗ eigvals(arg2(a)) +function LinearAlgebra.svd(ab::AbstractKroneckerArray; kwargs...) + Fa, Fb = LinearAlgebra.svd.(kroneckerfactors(ab); kwargs...) + return LinearAlgebra.SVD(Fa.U ⊗ Fb.U, Fa.S ⊗ Fb.S, Fa.Vt ⊗ Fb.Vt) end +LinearAlgebra.svdvals(a::AbstractKroneckerArray) = ⊗(LinearAlgebra.svdvals.(kroneckerfactors(a))...) -struct KroneckerQ{A1, A2} - arg1::A1 - arg2::A2 -end -@inline arg1(a::KroneckerQ) = getfield(a, :arg1) -@inline arg2(a::KroneckerQ) = getfield(a, :arg2) -function Base.:*(a::KroneckerQ, b::KroneckerQ) - return (arg1(a) * arg1(b)) ⊗ (arg2(a) * arg2(b)) -end -function Base.:*(a1::KroneckerQ, a2::AbstractKroneckerArray) - return (arg1(a1) * arg1(a2)) ⊗ (arg2(a1) * arg2(a2)) +function LinearAlgebra.eigen(a::AbstractKroneckerArray; kwargs...) + Fa, Fb = LinearAlgebra.eigen.(kroneckerfactors(a); kwargs...) + return LinearAlgebra.Eigen(Fa.values ⊗ Fb.values, Fa.vectors ⊗ Fb.vectors) end -function Base.:*(a1::AbstractKroneckerArray, a2::KroneckerQ) - return (arg1(a1) * arg1(a2)) ⊗ (arg2(a1) * arg2(a2)) -end -function Base.adjoint(a::KroneckerQ) - return KroneckerQ(arg1(a)', arg2(a)') +LinearAlgebra.eigvals(a::AbstractKroneckerArray) = ⊗(LinearAlgebra.eigvals.(kroneckerfactors(a))...) + +struct KroneckerQ{A, B} + a::A + b::B end +kroneckerfactors(q::KroneckerQ) = q.a, q.b +kroneckerfactortypes(::Type{KroneckerQ{A, B}}) where {A, B} = (A, B) + +Base.:*(a::KroneckerQ, b::KroneckerQ) = ⊗((kroneckerfactors(a) .* kroneckerfactors(b))...) +Base.:*(a::KroneckerQ, b::AbstractKroneckerArray) = ⊗((kroneckerfactors(a) .* kroneckerfactors(b))...) +Base.:*(a::AbstractKroneckerArray, b::KroneckerQ) = ⊗((kroneckerfactors(a) .* kroneckerfactors(b))...) +Base.adjoint(a::KroneckerQ) = KroneckerQ(adjoint.(kroneckerfactors(a))...) struct KroneckerQR{QQ, RR} Q::QQ @@ -166,12 +123,13 @@ end Base.iterate(F::KroneckerQR) = (F.Q, Val(:R)) Base.iterate(F::KroneckerQR, ::Val{:R}) = (F.R, Val(:done)) Base.iterate(F::KroneckerQR, ::Val{:done}) = nothing -function ⊗(a1::LinearAlgebra.QRCompactWYQ, a2::LinearAlgebra.QRCompactWYQ) - return KroneckerQ(a1, a2) + +function ⊗(a::LinearAlgebra.QRCompactWYQ, b::LinearAlgebra.QRCompactWYQ) + return KroneckerQ(a, b) end + function LinearAlgebra.qr(a::AbstractKroneckerArray) - Fa = qr(arg1(a)) - Fb = qr(arg2(a)) + Fa, Fb = LinearAlgebra.qr.(kroneckerfactors(a)) return KroneckerQR(Fa.Q ⊗ Fb.Q, Fa.R ⊗ Fb.R) end @@ -182,11 +140,12 @@ end Base.iterate(F::KroneckerLQ) = (F.L, Val(:Q)) Base.iterate(F::KroneckerLQ, ::Val{:Q}) = (F.Q, Val(:done)) Base.iterate(F::KroneckerLQ, ::Val{:done}) = nothing -function ⊗(a1::LinearAlgebra.LQPackedQ, a2::LinearAlgebra.LQPackedQ) - return KroneckerQ(a1, a2) + +function ⊗(a::LinearAlgebra.LQPackedQ, b::LinearAlgebra.LQPackedQ) + return KroneckerQ(a, b) end + function LinearAlgebra.lq(a::AbstractKroneckerArray) - Fa = lq(arg1(a)) - Fb = lq(arg2(a)) + Fa, Fb = LinearAlgebra.lq.(kroneckerfactors(a)) return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q) end diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index 5383130..bbb77fc 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -1,8 +1,6 @@ using MatrixAlgebraKit: MatrixAlgebraKit, AbstractAlgorithm, TruncationStrategy, - default_eig_algorithm, default_eigh_algorithm, default_lq_algorithm, - default_polar_algorithm, default_qr_algorithm, default_svd_algorithm, eig_full!, eig_full, eig_trunc!, eig_trunc, eig_vals!, eig_vals, eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eigh_vals!, eigh_vals, initialize_output, @@ -12,137 +10,99 @@ using MatrixAlgebraKit: right_null!, right_null, right_orth!, right_orth, right_polar!, right_polar, svd_compact!, svd_compact, svd_full!, svd_full, svd_trunc!, svd_trunc, svd_vals!, svd_vals, truncate - -using DiagonalArrays: DiagonalArrays, diagview -function DiagonalArrays.diagview(a::AbstractKroneckerMatrix) - return diagview(arg1(a)) ⊗ diagview(arg2(a)) -end -MatrixAlgebraKit.diagview(a::AbstractKroneckerMatrix) = diagview(a) - -struct KroneckerAlgorithm{A1, A2} <: AbstractAlgorithm - arg1::A1 - arg2::A2 -end -@inline arg1(alg::KroneckerAlgorithm) = getfield(alg, :arg1) -@inline arg2(alg::KroneckerAlgorithm) = getfield(alg, :arg2) - using MatrixAlgebraKit: - copy_input, eig_full, eig_vals, eigh_full, eigh_vals, qr_compact, qr_full, left_null, left_orth, left_polar, lq_compact, lq_full, right_null, right_orth, right_polar, svd_compact, svd_full +using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate +import MatrixAlgebraKit as MAK + +DiagonalArrays.diagview(a::AbstractKroneckerMatrix) = ⊗(DiagonalArrays.diagview.(kroneckerfactors(a))...) +MatrixAlgebraKit.diagview(a::AbstractKroneckerMatrix) = DiagonalArrays.diagview(a) + +struct KroneckerAlgorithm{A, B} <: AbstractAlgorithm + a::A + b::B +end + +kroneckerfactors(alg::KroneckerAlgorithm) = alg.a, alg.b +kroneckerfactortypes(::Type{KroneckerAlgorithm{A, B}}) where {A, B} = (A, B) -for f in [ +for f in ( :eig_full, :eigh_full, :qr_compact, :qr_full, :lq_compact, :lq_full, :left_polar, :right_polar, :svd_compact, :svd_full, - ] - @eval begin - function MatrixAlgebraKit.copy_input(::typeof($f), a::AbstractKroneckerMatrix) - return copy_input($f, arg1(a)) ⊗ copy_input($f, arg2(a)) - end - end + ) + @eval MAK.copy_input(::typeof($f), a::AbstractKroneckerMatrix) = + ⊗(MAK.copy_input.(($f,), kroneckerfactors(a))...) end -for f in [ +for f in ( :default_eig_algorithm, :default_eigh_algorithm, :default_lq_algorithm, :default_qr_algorithm, :default_polar_algorithm, :default_svd_algorithm, - ] - @eval begin - function MatrixAlgebraKit.$f( - A::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs... - ) - A1, A2 = argument_types(A) - return KroneckerAlgorithm( - $f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...) - ) - end + ) + @eval function MAK.$f(A::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs...) + A, B = kroneckerfactortypes(A) + return KroneckerAlgorithm( + MAK.$f(A; kwargs..., kwargs1...), + MAK.$f(B; kwargs..., kwargs2...) + ) end end -for f in [ +for f in ( :eig_full, :eigh_full, :left_polar, :right_polar, :lq_compact, :lq_full, :qr_compact, :qr_full, :svd_compact, :svd_full, - ] + ) f! = Symbol(f, :!) - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm - ) - return nothing - end - function MatrixAlgebraKit.$f!( - a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm - ) - a1 = $f(arg1(a), arg1(alg)) - a2 = $f(arg2(a), arg2(alg)) - return a1 .⊗ a2 - end - end + @eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing + @eval MAK.$f!(a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm) = + otimes.(MAK.$f.(kroneckerfactors(a), kroneckerfactors(alg))...) end -for f in [:eig_vals, :eigh_vals, :svd_vals] +for f in (:eig_vals, :eigh_vals, :svd_vals) f! = Symbol(f, :!) - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm - ) - return nothing - end - function MatrixAlgebraKit.$f!( - a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm - ) - a1 = $f(arg1(a), arg1(alg)) - a2 = $f(arg2(a), arg2(alg)) - return a1 ⊗ a2 - end + @eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing + @eval function MAK.$f!(a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm) + d1 = MAK.$f(kroneckerfactors(a, 1), kroneckerfactors(alg, 1)) + d2 = MAK.$f(kroneckerfactors(a, 2), kroneckerfactors(alg, 2)) + return d1 ⊗ d2 end end -for f in [:left_orth, :right_orth] +for f in (:left_orth, :right_orth) f! = Symbol(f, :!) - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) - return nothing - end - function MatrixAlgebraKit.$f!( - a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs... - ) - a1 = $f(arg1(a); kwargs..., kwargs1...) - a2 = $f(arg2(a); kwargs..., kwargs2...) - return a1 .⊗ a2 - end + @eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) = + nothing + @eval function MAK.$f!(a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...) + a1 = MAK.$f(kroneckerfactors(a, 1); kwargs..., kwargs1...) + a2 = MAK.$f(kroneckerfactors(a, 2); kwargs..., kwargs2...) + return a1 .⊗ a2 end end for f in [:left_null, :right_null] f! = Symbol(f, :!) - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::AbstractKroneckerMatrix) - return nothing - end - function MatrixAlgebraKit.$f!( - a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs... - ) - a1 = $f(arg1(a); kwargs..., kwargs1...) - a2 = $f(arg2(a); kwargs..., kwargs2...) - return a1 ⊗ a2 - end + @eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) = + nothing + @eval function MAK.$f!(a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...) + a1 = MAK.$f(kroneckerfactors(a, 1); kwargs..., kwargs1...) + a2 = MAK.$f(kroneckerfactors(a, 2); kwargs..., kwargs2...) + return a1 ⊗ a2 end end # Truncation -using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate struct KroneckerTruncationStrategy{T <: TruncationStrategy} <: TruncationStrategy strategy::T @@ -158,24 +118,24 @@ axis(a) = only(axes(a)) # Convert indices determined with a generic call to `findtruncated` to indices # more suited for a KroneckerVector. function to_truncated_indices(values::OnesKroneckerVector, I) - prods = cartesianproduct(axis(values))[I] - I_id = only(to_indices(arg1(values), (:,))) - I_data = unique(arg2.(prods)) + prods = cartesianproduct(kroneckerfactors(axis(values))...)[I] + I_id = only(to_indices(kroneckerfactors(values, 1), (:,))) + I_data = unique(kroneckerfactors.(prods, 2)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> arg2(x) == i, prods) == length(arg2(values)) + return count(x -> kroneckerfactors(x, 2) == i, prods) == length(kroneckerfactors(values, 2)) end return I_id × I_data end function to_truncated_indices(values::KroneckerOnesVector, I) #I = findtruncated(Vector(values), strategy.strategy) - prods = cartesianproduct(axis(values))[I] - I_data = unique(arg1.(prods)) + prods = cartesianproduct(kroneckerfactors(axis(values))...)[I] + I_data = unique(kroneckerfactors.(prods, 1)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> arg1(x) == i, prods) == length(arg2(values)) + return count(x -> kroneckerfactors(x, 1) == i, prods) == length(kroneckerfactors(values, 2)) end - I_id = only(to_indices(arg2(values), (:,))) + I_id = only(to_indices(kroneckerfactors(values, 2), (:,))) return I_data × I_id end # Fix ambiguity error. @@ -186,39 +146,32 @@ function to_truncated_indices(values::KroneckerVector, I) return throw(ArgumentError("Not implemented")) end -function MatrixAlgebraKit.findtruncated( +function MAK.findtruncated( values::AbstractKroneckerVector, strategy::KroneckerTruncationStrategy ) I = findtruncated(Vector(values), strategy.strategy) return to_truncated_indices(values, I) end -for f in [:eig_trunc!, :eigh_trunc!] - @eval begin - function MatrixAlgebraKit.truncate( - ::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix}, strategy::TruncationStrategy - ) - return truncate($f, DV, KroneckerTruncationStrategy(strategy)) - end - function MatrixAlgebraKit.truncate( - ::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy - ) - I = findtruncated(diagview(D), strategy) - return (D[I, I], V[(:) × (:), I]), I - end +for f in (:eig_trunc!, :eigh_trunc!) + @eval function MAK.truncate( + ::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix}, strategy::TruncationStrategy + ) + return MAK.truncate($f, DV, KroneckerTruncationStrategy(strategy)) + end + @eval function MAK.truncate( + ::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy + ) + I = MAK.findtruncated(MAK.diagview(D), strategy) + return (D[I, I], V[(:) × (:), I]), I end end -function MatrixAlgebraKit.truncate( - f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix}, strategy::TruncationStrategy - ) - return truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy)) -end -function MatrixAlgebraKit.truncate( - ::typeof(svd_trunc!), - (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix}, - strategy::KroneckerTruncationStrategy, +MAK.truncate(f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix}, strategy::TruncationStrategy) = + MAK.truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy)) +function MAK.truncate( + ::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy, ) - I = findtruncated(diagview(S), strategy) + I = MAK.findtruncated(MAK.diagview(S), strategy) return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]), I end From ab9728149563c38b961fcbb2ae374167a3cc75e5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Nov 2025 17:26:13 -0500 Subject: [PATCH 3/8] update basic tests accordingly --- test/test_basics.jl | 109 ++++++++++++++++++++++---------------------- 1 file changed, 55 insertions(+), 54 deletions(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index 6aa1068..b4b2422 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,8 +5,8 @@ using DiagonalArrays: diagonal using GPUArraysCore: @allowscalar using JLArrays: JLArray using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerStyle, - CartesianProductUnitRange, CartesianProductVector, ⊗, ×, arg1, arg2, cartesianproduct, - cartesianrange, kron_nd, unproduct + CartesianProductUnitRange, CartesianProductVector, ⊗, ×, kroneckerfactors, kroneckerfactortypes, + cartesianproduct, cartesianrange, kron_nd, unproduct using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr using StableRNGs: StableRNG using Test: @test, @test_broken, @test_throws, @testset @@ -22,8 +22,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test r === @constinferred(cartesianrange(2 × 3)) === @constinferred(cartesianrange(Base.OneTo(2), Base.OneTo(3))) === - @constinferred(cartesianrange(Base.OneTo(2) × Base.OneTo(3))) - @test @constinferred(cartesianproduct(r)) === Base.OneTo(2) × Base.OneTo(3) + @constinferred(Base.OneTo(2) × Base.OneTo(3)) @test unproduct(r) === Base.OneTo(6) @test length(r) == 6 @test first(r) == 1 @@ -35,28 +34,27 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test r[2 × 2] == 5 @test r[2 × 3] == 6 - @test sprint(show, "text/plain", cartesianrange(2 × 3)) == - "Base.OneTo(2) × Base.OneTo(3)\nBase.OneTo(6)" - @test sprint(show, cartesianrange(2 × 3)) == "Base.OneTo(6)" + @test sprint(show, cartesianrange(2, 3)) == "(Base.OneTo(2) × Base.OneTo(3))" + @test sprint(show, cartesianrange(2, 3, 2:7)) == "cartesianrange(Base.OneTo(2), Base.OneTo(3), 2:7)" # CartesianProductUnitRange axes - r = cartesianrange((2:3) × (3:4), 2:5) - @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + r = cartesianrange(2:3, 3:4, 2:5) + @test axes(r, 1) ≡ cartesianrange(2, 2) # CartesianProductUnitRange getindex - r1 = cartesianrange((2:4) × (3:5), 2:10) - r2 = cartesianrange((2:3) × (2:3), 2:5) - @test r1[r2] ≡ cartesianrange((3:4) × (4:5), 3:6) + r1 = cartesianrange(2:4, 3:5, 2:10) + r2 = cartesianrange(2:3, 2:3, 2:5) + @test r1[r2] ≡ cartesianrange(3:4, 4:5, 3:6) - @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + @test axes(r, 1) ≡ cartesianrange(2, 2) # CartesianProductVector axes - r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9]) - @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + r = cartesianproduct([2, 4], [3, 5], [3, 5, 7, 9]) + @test axes(r) ≡ (cartesianrange(2, 2),) r = @constinferred(cartesianrange(2 × 3, 2:7)) - @test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7) - @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) + @test r === cartesianrange(Base.OneTo(2), Base.OneTo(3), 2:7) + @test axes(r, 1) === Base.OneTo(2) × Base.OneTo(3) @test unproduct(r) === 2:7 @test length(r) == 6 @test first(r) == 2 @@ -80,23 +78,26 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) a = @constinferred(randn(rng, elt, 2, 2) ⊗ randn(rng, elt, 3, 3)) b = @constinferred(randn(rng, elt, 2, 2) ⊗ randn(rng, elt, 3, 3)) - c = @constinferred(a.arg1 ⊗ b.arg2) - @test a isa KroneckerArray{elt, 2, typeof(a.arg1), typeof(a.arg2)} + c = @constinferred(kroneckerfactors(a, 1) ⊗ kroneckerfactors(b, 2)) + @test a isa KroneckerArray{elt, 2, kroneckerfactortypes(a)...} @test similar(typeof(a), (2, 3)) isa Matrix{elt} @test size(similar(typeof(a), (2, 3))) == (2, 3) @test isreal(a) == (elt <: Real) - @test a[1 × 1, 1 × 1] == a.arg1[1, 1] * a.arg2[1, 1] - @test a[1 × 3, 2 × 1] == a.arg1[1, 2] * a.arg2[3, 1] - @test a[1 × (2:3), 2 × 1] == a.arg1[1, 2] * a.arg2[2:3, 1] - @test a[1 × :, (:) × 1] == a.arg1[1, :] ⊗ a.arg2[:, 1] - @test a[(1:2) × (2:3), (1:2) × (2:3)] == a.arg1[1:2, 1:2] ⊗ a.arg2[2:3, 2:3] + aa, ab = kroneckerfactors(a) + for i in 1:2, j in 1:3, k in 1:2, l in 1:3 + @test a[i × j, k × l] == aa[i, k] * ab[j, l] + end + @test a[1 × (2:3), 2 × 1] == aa[1, 2] * ab[2:3, 1] + @test a[1 × :, (:) × 1] == aa[1, :] ⊗ ab[:, 1] + @test a[(1:2) × (2:3), (1:2) × (2:3)] == aa[1:2, 1:2] ⊗ ab[2:3, 2:3] v = randn(elt, 2) ⊗ randn(elt, 3) - @test v[1 × 1] == v.arg1[1] * v.arg2[1] - @test v[1 × 3] == v.arg1[1] * v.arg2[3] - @test v[(1:2) × 3] == v.arg1[1:2] * v.arg2[3] - @test v[(1:2) × (2:3)] == v.arg1[1:2] ⊗ v.arg2[2:3] + va, vb = kroneckerfactors(v) + @test v[1 × 1] == va[1] * vb[1] + @test v[1 × 3] == va[1] * vb[3] + @test v[(1:2) × 3] == va[1:2] * vb[3] + @test v[(1:2) × (2:3)] == va[1:2] ⊗ vb[2:3] @test eltype(a) === elt - @test collect(a) == kron(collect(a.arg1), collect(a.arg2)) + @test collect(a) == kron(collect(aa), collect(ab)) @test size(a) == (6, 6) @test collect(a * b) ≈ collect(a) * collect(b) @test collect(-a) == -collect(a) @@ -116,14 +117,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) # Views a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3)) b = @constinferred(view(a, (1:2) × (2:3), (1:2) × (2:3))) - @test arg1(b) === view(arg1(a), 1:2, 1:2) - @test arg1(b) == arg1(a)[1:2, 1:2] - @test arg2(b) === view(arg2(a), 2:3, 2:3) - @test arg2(b) == arg2(a)[2:3, 2:3] + @test kroneckerfactors(b, 1) === view(kroneckerfactors(a, 1), 1:2, 1:2) + @test kroneckerfactors(b, 1) == kroneckerfactors(a, 1)[1:2, 1:2] + @test kroneckerfactors(b, 2) === view(kroneckerfactors(a, 2), 2:3, 2:3) + @test kroneckerfactors(b, 2) == kroneckerfactors(a, 2)[2:3, 2:3] # Broadcasting a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - style = KroneckerStyle(BroadcastStyle(typeof(a.arg1)), BroadcastStyle(typeof(a.arg2))) + style = KroneckerStyle(BroadcastStyle.(kroneckerfactortypes(a))...) @test BroadcastStyle(typeof(a)) === style @test_throws "not supported" sin.(a) a′ = similar(a) @@ -133,7 +134,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test collect(a′) ≈ 2 * collect(a) bc = broadcasted(+, a, a) @test bc.style === style - @test similar(bc, elt) isa KroneckerArray{elt, 2, typeof(a.arg1), typeof(a.arg2)} + @test similar(bc, elt) isa KroneckerArray{elt, 2, kroneckerfactortypes(a)...} @test collect(copy(bc)) ≈ 2 * collect(a) bc = broadcasted(*, 2, a) @test bc.style === style @@ -182,37 +183,38 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) # permutedims a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) @test permutedims(a, (2, 1, 3)) == - permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) + permutedims(kroneckerfactors(a, 1), (2, 1, 3)) ⊗ permutedims(kroneckerfactors(a, 2), (2, 1, 3)) # permutedims! a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) b = similar(a) permutedims!(b, a, (2, 1, 3)) - @test b == permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) + @test b == permutedims(kroneckerfactors(a, 1), (2, 1, 3)) ⊗ permutedims(kroneckerfactors(a, 2), (2, 1, 3)) # Adapt a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) a′ = adapt(JLArray, a) @test a′ isa KroneckerArray{elt, 2, JLArray{elt, 2}, JLArray{elt, 2}} - @test a′.arg1 isa JLArray{elt, 2} - @test a′.arg2 isa JLArray{elt, 2} - @test Array(a′.arg1) == a.arg1 - @test Array(a′.arg2) == a.arg2 + @test kroneckerfactors(a′, 1) isa JLArray{elt, 2} + @test kroneckerfactors(a′, 2) isa JLArray{elt, 2} + @test Array(kroneckerfactors(a′, 1)) == kroneckerfactors(a, 1) + @test Array(kroneckerfactors(a′, 2)) == kroneckerfactors(a, 2) a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) - @test collect(a) ≈ kron_nd(a.arg1, a.arg2) - @test a[1 × 1, 1 × 1, 1 × 1] == a.arg1[1, 1, 1] * a.arg2[1, 1, 1] - @test a[1 × 3, 2 × 1, 2 × 2] == a.arg1[1, 2, 2] * a.arg2[3, 1, 2] + @test collect(a) ≈ kron_nd(kroneckerfactors(a)...) + for i in 1:2, j in 1:3, k in 1:2, l in 1:3, m in 1:2, n in 1:3 + @test a[i × j, k × l, m × n] == kroneckerfactors(a, 1)[i, k, m] * kroneckerfactors(a, 2)[j, l, n] + end @test collect(a + a) ≈ 2 * collect(a) @test collect(2a) ≈ 2 * collect(a) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c = arg1(a) ⊗ arg2(b) + c = kroneckerfactors(a, 1) ⊗ kroneckerfactors(b, 2) U, S, V = svd(a) @test collect(U * diagonal(S) * V') ≈ collect(a) - @test arg1(svdvals(a)) ≈ arg1(S) - @test arg2(svdvals(a)) ≈ arg2(S) + @test kroneckerfactors(svdvals(a), 1) ≈ kroneckerfactors(S, 1) + @test kroneckerfactors(svdvals(a), 2) ≈ kroneckerfactors(S, 2) @test sort(collect(S); rev = true) ≈ svdvals(collect(a)) @test collect(U'U) ≈ I @test collect(V * V') ≈ I @@ -232,9 +234,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) for f in KroneckerArrays.MATRIX_FUNCTIONS - @eval begin - @test_throws ArgumentError $f($a) - end + @eval @test_throws ArgumentError $f($a) end # isapprox @@ -276,8 +276,9 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) # KroneckerArrays.dist_kronecker rng = StableRNG(123) - a = randn(rng, (100, 100)) ⊗ randn(rng, (100, 100)) - b = (arg1(a) + randn(rng, size(arg1(a))) / 10) ⊗ - (arg2(a) + randn(rng, size(arg2(a))) / 10) - @test KroneckerArrays.dist_kronecker(a, b) ≈ norm(collect(a) - collect(b)) rtol = 1.0e-2 + a = randn(rng, (100, 100)) + b = randn(rng, (100, 100)) + ab = a ⊗ b + ab′ = (a + randn(rng, size(a)) / 10) ⊗ (b + randn(rng, size(b)) / 10) + @test KroneckerArrays.dist_kronecker(ab, ab′) ≈ norm(collect(ab) - collect(ab′)) rtol = 1.0e-2 end From f66ec9aa40deec1f5dbe7f662e7690fe88b1fee6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Nov 2025 19:37:40 -0500 Subject: [PATCH 4/8] update blocksparsearray tests accordingly --- test/test_blocksparsearrays.jl | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 3cf1d8d..29dfd24 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -4,7 +4,7 @@ using BlockSparseArrays: BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype, eachblockaxis using DiagonalArrays: DeltaMatrix, δ using JLArrays: JLArray -using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2, cartesianrange +using KroneckerArrays: KroneckerArray, ⊗, ×, kroneckerfactors, cartesianrange using LinearAlgebra: norm using MatrixAlgebraKit: svd_compact, svd_trunc using StableRNGs: StableRNG @@ -14,17 +14,16 @@ using TestExtras: @constinferred, @constinferred_broken elts = (Float32, Float64, ComplexF32) arrayts = (Array, JLArray) @testset "BlockSparseArraysExt, KroneckerArray blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in - arrayts, - elt in elts + arrayts, elt in elts # BlockUnitRange with CartesianProduct blocks - r = blockrange([2 × 3, 3 × 4]) - @test r[Block(1)] ≡ cartesianrange(2 × 3, 1:6) - @test r[Block(2)] ≡ cartesianrange(3 × 4, 7:18) + r = blockrange([cartesianrange(2, 3), cartesianrange(3, 4)]) + @test r[Block(1)] ≡ cartesianrange(2, 3, 1:6) + @test r[Block(2)] ≡ cartesianrange(3, 4, 7:18) @test eachblockaxis(r)[1] ≡ cartesianrange(2, 3) @test eachblockaxis(r)[2] ≡ cartesianrange(3, 4) - @test blockisequal(arg1(r), blockedrange([2, 3])) - @test blockisequal(arg2(r), blockedrange([3, 4])) + @test blockisequal(kroneckerfactors(r, 1), blockedrange([2, 3])) + @test blockisequal(kroneckerfactors(r, 2), blockedrange([3, 4])) r = blockrange([2 × 3, 3 × 4]) r′ = r[Block.([2, 1])] @@ -198,9 +197,7 @@ arrayts = (Array, JLArray) end @testset "BlockSparseArraysExt, DeltaKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in - arrayts, - elt in elts - + arrayts, elt in elts dev = adapt(arrayt) r = @constinferred blockrange([2 × 2, 2 × 3]) d = Dict( @@ -248,13 +245,13 @@ end I = [I1, I2] b = a[I, I] @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] - @test arg1(b[Block(1, 1)]) isa DeltaMatrix + @test kroneckerfactors(b[Block(1, 1)], 1) isa DeltaMatrix @test iszero(b[Block(2, 1)]) - @test arg1(b[Block(2, 1)]) isa DeltaMatrix + @test kroneckerfactors(b[Block(2, 1)], 1) isa DeltaMatrix @test iszero(b[Block(1, 2)]) - @test arg1(b[Block(1, 2)]) isa DeltaMatrix + @test kroneckerfactors(b[Block(1, 2)], 1) isa DeltaMatrix @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] - @test arg1(b[Block(2, 2)]) isa DeltaMatrix + @test kroneckerfactors(b[Block(2, 2)], 1) isa DeltaMatrix # Slicing r = blockrange([2 × 2, 3 × 3]) From 25eeb0a3e98c4f1e7b6aa4622445d91f93339f4b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Nov 2025 20:21:00 -0500 Subject: [PATCH 5/8] update the rest of the tests --- test/test_delta.jl | 274 ++++++++++++++-------------- test/test_matrixalgebrakit.jl | 58 +++--- test/test_matrixalgebrakit_delta.jl | 151 +++++++-------- test/test_tensoralgebra.jl | 4 +- test/test_tensorproducts.jl | 6 +- 5 files changed, 239 insertions(+), 254 deletions(-) diff --git a/test/test_delta.jl b/test/test_delta.jl index 256175a..763d368 100644 --- a/test/test_delta.jl +++ b/test/test_delta.jl @@ -3,7 +3,7 @@ using DerivableInterfaces: zero! using DiagonalArrays: δ using FillArrays: Eye, Zeros using JLArrays: JLArray, jl -using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, arg1, arg2, cartesianrange +using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, kroneckerfactors, cartesianrange using LinearAlgebra: det, norm, pinv using StableRNGs: StableRNG using Test: @test, @test_broken, @test_throws, @testset @@ -18,183 +18,183 @@ using TestExtras: @constinferred a = Eye(2) ⊗ randn(3, 3) @test size(a) == (6, 6) - @test a + a == Eye(2) ⊗ (2 * arg2(a)) - @test 2a == Eye(2) ⊗ (2 * arg2(a)) - @test a * a == Eye(2) ⊗ (arg2(a) * arg2(a)) - @test_broken arg1(a[(:) × (:), (:) × (:)]) ≡ Eye(2) - @test_broken arg1(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) - @test_broken arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) - @test_broken arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) - @test_broken arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) - @test_broken arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) - @test_broken arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ + @test a + a == Eye(2) ⊗ (2 * kroneckerfactors(a, 2)) + @test 2a == Eye(2) ⊗ (2 * kroneckerfactors(a, 2)) + @test a * a == Eye(2) ⊗ (kroneckerfactors(a, 2) * kroneckerfactors(a, 2)) + @test_broken kroneckerfactors(a[(:) × (:), (:) × (:)], 1) ≡ Eye(2) + @test_broken kroneckerfactors(view(a, (:) × (:), (:) × (:)), 1) ≡ Eye(2) + @test_broken kroneckerfactors(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)], 1) ≡ Eye(2) + @test_broken kroneckerfactors(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:)), 1) ≡ Eye(2) + @test_broken kroneckerfactors(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)], 1) ≡ Eye(2) + @test_broken kroneckerfactors(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:)), 1) ≡ Eye(2) + @test_broken kroneckerfactors(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)], 1) ≡ Eye(2) - @test_broken arg1( - view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)) + @test_broken kroneckerfactors( + view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)), 1 ) ≡ Eye(2) - @test arg1(adapt(JLArray, a)) ≡ Eye(2) - @test arg2(adapt(JLArray, a)) == jl(arg2(a)) - @test arg2(adapt(JLArray, a)) isa JLArray - @test_broken arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye(3) - @test_broken arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + @test kroneckerfactors(adapt(JLArray, a), 1) ≡ Eye(2) + @test kroneckerfactors(adapt(JLArray, a), 2) == jl(kroneckerfactors(a, 2)) + @test kroneckerfactors(adapt(JLArray, a), 2) isa JLArray + @test_broken kroneckerfactors(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2))), 1) ≡ Eye(3) + @test_broken kroneckerfactors(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2))), 1) ≡ Eye(3) - @test_broken arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + @test_broken kroneckerfactors(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))), 1 ≡ Eye{Float32}(3) - @test arg1(copy(a)) ≡ Eye(2) - @test arg2(copy(a)) == arg2(a) + @test kroneckerfactors(copy(a), 1) ≡ Eye(2) + @test kroneckerfactors(copy(a), 2) == kroneckerfactors(a, 2) b = similar(a) - @test arg1(copyto!(b, a)) ≡ Eye(2) - @test arg2(copyto!(b, a)) == arg2(a) - @test arg1(permutedims(a, (2, 1))) ≡ Eye(2) - @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) + @test kroneckerfactors(copyto!(b, a), 1) ≡ Eye(2) + @test kroneckerfactors(copyto!(b, a), 2) == kroneckerfactors(a, 2) + @test kroneckerfactors(permutedims(a, (2, 1)), 1) ≡ Eye(2) + @test kroneckerfactors(permutedims(a, (2, 1)), 2) == permutedims(kroneckerfactors(a, 2), (2, 1)) b = similar(a) - @test arg1(permutedims!(b, a, (2, 1))) ≡ Eye(2) - @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) + @test kroneckerfactors(permutedims!(b, a, (2, 1)), 1) ≡ Eye(2) + @test kroneckerfactors(permutedims!(b, a, (2, 1)), 2) == permutedims(kroneckerfactors(a, 2), (2, 1)) a = randn(3, 3) ⊗ Eye(2) @test size(a) == (6, 6) - @test a + a == (2 * arg1(a)) ⊗ Eye(2) - @test 2a == (2 * arg1(a)) ⊗ Eye(2) - @test a * a == (arg1(a) * arg1(a)) ⊗ Eye(2) - @test_broken arg2(a[(:) × (:), (:) × (:)]) ≡ Eye(2) - @test_broken arg2(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) - @test_broken arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) - @test_broken arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) - @test_broken arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) - @test_broken arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) - @test_broken arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ + @test a + a == (2 * kroneckerfactors(a, 1)) ⊗ Eye(2) + @test 2a == (2 * kroneckerfactors(a, 1)) ⊗ Eye(2) + @test a * a == (kroneckerfactors(a, 1) * kroneckerfactors(a, 1)) ⊗ Eye(2) + @test_broken kroneckerfactors(a[(:) × (:), (:) × (:)], 2) ≡ Eye(2) + @test_broken kroneckerfactors(view(a, (:) × (:), (:) × (:)), 2) ≡ Eye(2) + @test_broken kroneckerfactors(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)], 2) ≡ Eye(2) + @test_broken kroneckerfactors(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:)), 2) ≡ Eye(2) + @test_broken kroneckerfactors(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)], 2) ≡ Eye(2) + @test_broken kroneckerfactors(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:)), 2) ≡ Eye(2) + @test_broken kroneckerfactors(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)], 2) ≡ Eye(2) - @test_broken arg2( - view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)) + @test_broken kroneckerfactors( + view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)), 2 ) ≡ Eye(2) - @test arg2(adapt(JLArray, a)) ≡ Eye(2) - @test arg1(adapt(JLArray, a)) == jl(arg1(a)) - @test arg1(adapt(JLArray, a)) isa JLArray - @test_broken arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ Eye(3) - @test_broken arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + @test kroneckerfactors(adapt(JLArray, a), 2) ≡ Eye(2) + @test kroneckerfactors(adapt(JLArray, a), 1) == jl(kroneckerfactors(a, 1)) + @test kroneckerfactors(adapt(JLArray, a), 1) isa JLArray + @test_broken kroneckerfactors(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3))), 2) ≡ Eye(3) + @test_broken kroneckerfactors(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3))), 2) ≡ Eye(3) - @test_broken arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + @test_broken kroneckerfactors(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3))), 2) ≡ Eye{Float32}(3) - @test arg2(copy(a)) ≡ Eye(2) - @test arg2(copy(a)) == arg2(a) + @test kroneckerfactors(copy(a), 2) ≡ Eye(2) + @test kroneckerfactors(copy(a), 2) == kroneckerfactors(a, 2) b = similar(a) - @test arg2(copyto!(b, a)) ≡ Eye(2) - @test arg2(copyto!(b, a)) == arg2(a) - @test arg2(permutedims(a, (2, 1))) ≡ Eye(2) - @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) + @test kroneckerfactors(copyto!(b, a), 2) ≡ Eye(2) + @test kroneckerfactors(copyto!(b, a), 2) == kroneckerfactors(a, 2) + @test kroneckerfactors(permutedims(a, (2, 1)), 2) ≡ Eye(2) + @test kroneckerfactors(permutedims(a, (2, 1)), 1) == permutedims(kroneckerfactors(a, 1), (2, 1)) b = similar(a) - @test arg2(permutedims!(b, a, (2, 1))) ≡ Eye(2) - @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) + @test kroneckerfactors(permutedims!(b, a, (2, 1)), 2) ≡ Eye(2) + @test kroneckerfactors(permutedims!(b, a, (2, 1)), 1) == permutedims(kroneckerfactors(a, 1), (2, 1)) a = δ(2, 2) ⊗ randn(3, 3) @test size(a) == (6, 6) - @test a + a == δ(2, 2) ⊗ (2 * arg2(a)) - @test 2a == δ(2, 2) ⊗ (2 * arg2(a)) - @test a * a == δ(2, 2) ⊗ (arg2(a) * arg2(a)) - @test arg1(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) - @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) - @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) - @test arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) - @test arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) - @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) - @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + @test a + a == δ(2, 2) ⊗ (2 * kroneckerfactors(a, 2)) + @test 2a == δ(2, 2) ⊗ (2 * kroneckerfactors(a, 2)) + @test a * a == δ(2, 2) ⊗ (kroneckerfactors(a, 2) * kroneckerfactors(a, 2)) + @test kroneckerfactors(a[(:) × (:), (:) × (:)], 1) ≡ δ(2, 2) + @test kroneckerfactors(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)], 1) ≡ δ(2, 2) + @test kroneckerfactors(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:)), 1) ≡ δ(2, 2) + @test kroneckerfactors(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)], 1) ≡ δ(2, 2) + @test kroneckerfactors(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:)), 1) ≡ δ(2, 2) + @test kroneckerfactors(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)], 1) ≡ δ(2, 2) + @test kroneckerfactors(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)), 1) ≡ δ(2, 2) - @test arg1(adapt(JLArray, a)) ≡ δ(2, 2) - @test arg2(adapt(JLArray, a)) == jl(arg2(a)) - @test arg2(adapt(JLArray, a)) isa JLArray - @test_broken arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(3, 3) - @test_broken arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + @test kroneckerfactors(adapt(JLArray, a), 1) ≡ δ(2, 2) + @test kroneckerfactors(adapt(JLArray, a), 2) == jl(kroneckerfactors(a, 2)) + @test kroneckerfactors(adapt(JLArray, a), 2) isa JLArray + @test_broken kroneckerfactors(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2))), 1) ≡ δ(3, 3) + @test_broken kroneckerfactors(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2))), 1) ≡ δ(3, 3) - @test_broken arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + @test_broken kroneckerfactors(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2))), 1) ≡ δ(Float32, 3, 3) - @test arg1(copy(a)) ≡ δ(2, 2) - @test arg2(copy(a)) == arg2(a) + @test kroneckerfactors(copy(a), 1) ≡ δ(2, 2) + @test kroneckerfactors(copy(a), 2) == kroneckerfactors(a, 2) b = similar(a) - @test arg1(copyto!(b, a)) ≡ δ(2, 2) - @test arg2(copyto!(b, a)) == arg2(a) - @test arg1(permutedims(a, (2, 1))) ≡ δ(2, 2) - @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) + @test kroneckerfactors(copyto!(b, a), 1) ≡ δ(2, 2) + @test kroneckerfactors(copyto!(b, a), 2) == kroneckerfactors(a, 2) + @test kroneckerfactors(permutedims(a, (2, 1)), 1) ≡ δ(2, 2) + @test kroneckerfactors(permutedims(a, (2, 1)), 2) == permutedims(kroneckerfactors(a, 2), (2, 1)) b = similar(a) - @test arg1(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) - @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) + @test kroneckerfactors(permutedims!(b, a, (2, 1)), 1) ≡ δ(2, 2) + @test kroneckerfactors(permutedims!(b, a, (2, 1)), 2) == permutedims(kroneckerfactors(a, 2), (2, 1)) a = randn(3, 3) ⊗ δ(2, 2) @test size(a) == (6, 6) - @test a + a == (2 * arg1(a)) ⊗ δ(2, 2) - @test 2a == (2 * arg1(a)) ⊗ δ(2, 2) - @test a * a == (arg1(a) * arg1(a)) ⊗ δ(2, 2) - @test arg2(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) - @test arg2(view(a, (:) × (:), (:) × (:))) ≡ δ(2, 2) - @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) - @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) - @test arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) - @test arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) - @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) - @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + @test a + a == (2 * kroneckerfactors(a, 1)) ⊗ δ(2, 2) + @test 2a == (2 * kroneckerfactors(a, 1)) ⊗ δ(2, 2) + @test a * a == (kroneckerfactors(a, 1) * kroneckerfactors(a, 1)) ⊗ δ(2, 2) + @test kroneckerfactors(a[(:) × (:), (:) × (:)], 2) ≡ δ(2, 2) + @test kroneckerfactors(view(a, (:) × (:), (:) × (:)), 2) ≡ δ(2, 2) + @test kroneckerfactors(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)], 2) ≡ δ(2, 2) + @test kroneckerfactors(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:)), 2) ≡ δ(2, 2) + @test kroneckerfactors(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)], 2) ≡ δ(2, 2) + @test kroneckerfactors(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:)), 2) ≡ δ(2, 2) + @test kroneckerfactors(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)], 2) ≡ δ(2, 2) + @test kroneckerfactors(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)), 2) ≡ δ(2, 2) - @test arg2(adapt(JLArray, a)) ≡ δ(2, 2) - @test arg1(adapt(JLArray, a)) == jl(arg1(a)) - @test arg1(adapt(JLArray, a)) isa JLArray - @test_broken arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ δ(3, 3) - @test_broken arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + @test kroneckerfactors(adapt(JLArray, a), 2) ≡ δ(2, 2) + @test kroneckerfactors(adapt(JLArray, a), 1) == jl(kroneckerfactors(a, 1)) + @test kroneckerfactors(adapt(JLArray, a), 1) isa JLArray + @test_broken kroneckerfactors(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3))), 2) ≡ δ(3, 3) + @test_broken kroneckerfactors(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3))), 2) ≡ δ(3, 3) - @test_broken arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + @test_broken kroneckerfactors(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3))), 2) ≡ δ(Float32, (3, 3)) - @test arg2(copy(a)) ≡ δ(2, 2) - @test arg2(copy(a)) == arg2(a) + @test kroneckerfactors(copy(a), 2) ≡ δ(2, 2) + @test kroneckerfactors(copy(a), 2) == kroneckerfactors(a, 2) b = similar(a) - @test arg2(copyto!(b, a)) ≡ δ(2, 2) - @test arg2(copyto!(b, a)) == arg2(a) - @test arg2(permutedims(a, (2, 1))) ≡ δ(2, 2) - @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) + @test kroneckerfactors(copyto!(b, a), 2) ≡ δ(2, 2) + @test kroneckerfactors(copyto!(b, a), 2) == kroneckerfactors(a, 2) + @test kroneckerfactors(permutedims(a, (2, 1)), 2) ≡ δ(2, 2) + @test kroneckerfactors(permutedims(a, (2, 1)), 1) == permutedims(kroneckerfactors(a, 1), (2, 1)) b = similar(a) - @test arg2(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) - @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) + @test kroneckerfactors(permutedims!(b, a, (2, 1)), 2) ≡ δ(2, 2) + @test kroneckerfactors(permutedims!(b, a, (2, 1)), 1) == permutedims(kroneckerfactors(a, 1), (2, 1)) # Views a = @constinferred(Eye(2) ⊗ randn(3, 3)) b = @constinferred(view(a, (:) × (2:3), (:) × (2:3))) - @test_broken arg1(b) ≡ Eye(2) - @test arg2(b) ≡ view(arg2(a), 2:3, 2:3) - @test arg2(b) == arg2(a)[2:3, 2:3] + @test_broken kroneckerfactors(b, 1) ≡ Eye(2) + @test kroneckerfactors(b, 2) ≡ view(kroneckerfactors(a, 2), 2:3, 2:3) + @test kroneckerfactors(b, 2) == kroneckerfactors(a, 2)[2:3, 2:3] a = randn(3, 3) ⊗ Eye(2) @test size(a) == (6, 6) - @test a + a == (2arg1(a)) ⊗ Eye(2) - @test 2a == (2arg1(a)) ⊗ Eye(2) - @test a * a == (arg1(a) * arg1(a)) ⊗ Eye(2) + @test a + a == (2kroneckerfactors(a, 1)) ⊗ Eye(2) + @test 2a == (2kroneckerfactors(a, 1)) ⊗ Eye(2) + @test a * a == (kroneckerfactors(a, 1) * kroneckerfactors(a, 1)) ⊗ Eye(2) # Views a = @constinferred(randn(3, 3) ⊗ Eye(2)) b = @constinferred(view(a, (2:3) × (:), (2:3) × (:))) - @test arg1(b) ≡ view(arg1(a), 2:3, 2:3) - @test arg1(b) == arg1(a)[2:3, 2:3] - @test_broken arg2(b) ≡ Eye(2) + @test kroneckerfactors(b, 1) ≡ view(kroneckerfactors(a, 1), 2:3, 2:3) + @test kroneckerfactors(b, 1) == kroneckerfactors(a, 1)[2:3, 2:3] + @test_broken kroneckerfactors(b, 2) ≡ Eye(2) # similar a = Eye(2) ⊗ randn(3, 3) a′ = similar(a) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{eltype(a), ndims(a)} - @test arg1(a′) ≡ arg1(a) + @test kroneckerfactors(a′, 1) ≡ kroneckerfactors(a, 1) a = Eye(2) ⊗ randn(3, 3) a′ = similar(a, eltype(a)) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{eltype(a), ndims(a)} - @test arg1(a′) ≡ arg1(a) + @test kroneckerfactors(a′, 1) ≡ kroneckerfactors(a, 1) a = Eye(2) ⊗ randn(3, 3) a′ = similar(a, axes(a)) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{eltype(a), ndims(a)} - @test arg1(a′) ≡ arg1(a) + @test kroneckerfactors(a′, 1) ≡ kroneckerfactors(a, 1) a = Eye(2) ⊗ randn(3, 3) a′ = similar(a, eltype(a), axes(a)) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{eltype(a), ndims(a)} - @test arg1(a′) ≡ arg1(a) + @test kroneckerfactors(a′, 1) ≡ kroneckerfactors(a, 1) @test_broken similar(typeof(a), axes(a)) @@ -202,37 +202,37 @@ using TestExtras: @constinferred a′ = similar(a, Float32) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{Float32, ndims(a)} - @test_broken arg1(a′) ≡ Eye{Float32}(2) + @test_broken kroneckerfactors(a′, 1) ≡ Eye{Float32}(2) a = Eye(2) ⊗ randn(3, 3) a′ = similar(a, Float32, axes(a)) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{Float32, ndims(a)} - @test_broken arg1(a′) ≡ Eye{Float32}(2) + @test_broken kroneckerfactors(a′, 1) ≡ Eye{Float32}(2) a = randn(3, 3) ⊗ Eye(2) a′ = similar(a) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{eltype(a), ndims(a)} - @test arg2(a′) ≡ arg2(a) + @test kroneckerfactors(a′, 2) ≡ kroneckerfactors(a, 2) a = randn(3, 3) ⊗ Eye(2) a′ = similar(a, eltype(a)) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{eltype(a), ndims(a)} - @test arg2(a′) ≡ arg2(a) + @test kroneckerfactors(a′, 2) ≡ kroneckerfactors(a, 2) a = randn(3, 3) ⊗ Eye(2) a′ = similar(a, axes(a)) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{eltype(a), ndims(a)} - @test arg2(a′) ≡ arg2(a) + @test kroneckerfactors(a′, 2) ≡ kroneckerfactors(a, 2) a = randn(3, 3) ⊗ Eye(2) a′ = similar(a, eltype(a), axes(a)) @test size(a′) == (6, 6) @test a′ isa KroneckerArray{eltype(a), ndims(a)} - @test arg2(a′) ≡ arg2(a) + @test kroneckerfactors(a′, 2) ≡ kroneckerfactors(a, 2) @test_broken similar(typeof(a), axes(a)) @@ -242,7 +242,7 @@ using TestExtras: @constinferred @test a′ isa KroneckerArray{Float32, ndims(a)} # This is broken because of: # https://github.com/JuliaArrays/FillArrays.jl/issues/415 - @test_broken arg2(a′) ≡ Eye{Float32}(2) + @test_broken kroneckerfactors(a′, 2) ≡ Eye{Float32}(2) a = randn(3, 3) ⊗ Eye(2) a′ = similar(a, Float32, axes(a)) @@ -313,17 +313,17 @@ using TestExtras: @constinferred ## @eval begin ## fa = $f($a) ## @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) - ## @test arg1(fa) isa Eye + ## @test kroneckerfactors(fa) isa Eye ## end ## end fa = inv(a) @test collect(fa) ≈ inv(collect(a)) - @test arg1(fa) isa Eye + @test kroneckerfactors(fa, 1) isa Eye fa = pinv(a) @test collect(fa) ≈ pinv(collect(a)) - @test_broken arg1(fa) isa Eye + @test_broken kroneckerfactors(fa, 1) isa Eye @test det(a) ≈ det(collect(a)) @@ -334,17 +334,17 @@ using TestExtras: @constinferred ## @eval begin ## fa = $f($a) ## @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) - ## @test arg2(fa) isa Eye + ## @test kroneckerfactors(fa) isa Eye ## end ## end fa = inv(a) @test collect(fa) ≈ inv(collect(a)) - @test arg2(fa) isa Eye + @test kroneckerfactors(fa, 2) isa Eye fa = pinv(a) @test collect(fa) ≈ pinv(collect(a)) - @test_broken arg2(fa) isa Eye + @test_broken kroneckerfactors(fa, 2) isa Eye @test det(a) ≈ det(collect(a)) @@ -352,39 +352,39 @@ using TestExtras: @constinferred a = Eye(2) ⊗ Eye(2) for f in MATRIX_FUNCTIONS @eval begin - @test $f($a) == arg1($a) ⊗ $f(arg2($a)) + @test $f($a) == kroneckerfactors($a, 1) ⊗ $f(kroneckerfactors($a, 2)) end end fa = inv(a) @test fa == a - @test arg1(fa) isa Eye - @test arg2(fa) isa Eye + @test kroneckerfactors(fa, 1) isa Eye + @test kroneckerfactors(fa, 2) isa Eye fa = pinv(a) @test fa == a - @test_broken arg1(fa) isa Eye - @test_broken arg2(fa) isa Eye + @test_broken kroneckerfactors(fa, 1) isa Eye + @test_broken kroneckerfactors(fa, 2) isa Eye @test det(a) ≈ det(collect(a)) ≈ 1 # permutedims a = Eye(2, 2) ⊗ randn(3, 3) - @test permutedims(a, (2, 1)) == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) + @test permutedims(a, (2, 1)) == Eye(2, 2) ⊗ permutedims(kroneckerfactors(a, 2), (2, 1)) a = randn(2, 2) ⊗ Eye(3, 3) - @test permutedims(a, (2, 1)) == permutedims(arg1(a), (2, 1)) ⊗ Eye(3, 3) + @test permutedims(a, (2, 1)) == permutedims(kroneckerfactors(a, 1), (2, 1)) ⊗ Eye(3, 3) # permutedims! a = Eye(2, 2) ⊗ randn(3, 3) b = similar(a) permutedims!(b, a, (2, 1)) - @test b == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) + @test b == Eye(2, 2) ⊗ permutedims(kroneckerfactors(a, 2), (2, 1)) a = randn(3, 3) ⊗ Eye(2, 2) b = similar(a) permutedims!(b, a, (2, 1)) - @test b == permutedims(arg1(a), (2, 1)) ⊗ Eye(2, 2) + @test b == permutedims(kroneckerfactors(a, 1), (2, 1)) ⊗ Eye(2, 2) end @testset "FillArrays.Zeros" begin diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index f62c492..5be7fcd 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,4 +1,4 @@ -using KroneckerArrays: ⊗, arg1, arg2 +using KroneckerArrays: ⊗, kroneckerfactors using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc, eigh_vals, left_null, left_orth, left_polar, lq_compact, lq_full, qr_compact, @@ -16,8 +16,8 @@ herm(a) = parent(hermitianpart(a)) d, v = eig_full(a) av = a * v vd = v * d - @test arg1(av) ≈ arg1(vd) - @test arg2(av) ≈ arg2(vd) + @test kroneckerfactors(av, 1) ≈ kroneckerfactors(vd, 1) + @test kroneckerfactors(av, 2) ≈ kroneckerfactors(vd, 2) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) @test_throws ArgumentError eig_trunc(a) @@ -25,15 +25,15 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) d = eig_vals(a) d′ = diag(eig_full(a)[1]) - @test arg1(d) ≈ arg1(d′) - @test arg2(d) ≈ arg2(d′) + @test kroneckerfactors(d, 1) ≈ kroneckerfactors(d′, 1) + @test kroneckerfactors(d, 2) ≈ kroneckerfactors(d′, 2) a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) d, v = eigh_full(a) av = a * v vd = v * d - @test arg1(av) ≈ arg1(vd) - @test arg2(av) ≈ arg2(vd) + @test kroneckerfactors(av, 1) ≈ kroneckerfactors(vd, 1) + @test kroneckerfactors(av, 2) ≈ kroneckerfactors(vd, 2) @test eltype(d) === real(elt) @test eltype(v) === elt @@ -48,29 +48,29 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = qr_compact(a) uc = u * c - @test arg1(uc) ≈ arg1(a) - @test arg2(uc) ≈ arg2(a) + @test kroneckerfactors(uc, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(uc, 2) ≈ kroneckerfactors(a, 2) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = qr_full(a) uc = u * c - @test arg1(uc) ≈ arg1(a) - @test arg2(uc) ≈ arg2(a) + @test kroneckerfactors(uc, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(uc, 2) ≈ kroneckerfactors(a, 2) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = lq_compact(a) cu = c * u - @test arg1(cu) ≈ arg1(a) - @test arg2(cu) ≈ arg2(a) + @test kroneckerfactors(cu, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(cu, 2) ≈ kroneckerfactors(a, 2) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = lq_full(a) cu = c * u - @test arg1(cu) ≈ arg1(a) - @test arg2(cu) ≈ arg2(a) + @test kroneckerfactors(cu, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(cu, 2) ≈ kroneckerfactors(a, 2) @test collect(u * u') ≈ I a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3) @@ -84,36 +84,36 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = left_orth(a) uc = u * c - @test arg1(uc) ≈ arg1(a) - @test arg2(uc) ≈ arg2(a) + @test kroneckerfactors(uc, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(uc, 2) ≈ kroneckerfactors(a, 2) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = right_orth(a) cu = c * u - @test arg1(cu) ≈ arg1(a) - @test arg2(cu) ≈ arg2(a) + @test kroneckerfactors(cu, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(cu, 2) ≈ kroneckerfactors(a, 2) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = left_polar(a) uc = u * c - @test arg1(uc) ≈ arg1(a) - @test arg2(uc) ≈ arg2(a) + @test kroneckerfactors(uc, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(uc, 2) ≈ kroneckerfactors(a, 2) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = right_polar(a) cu = c * u - @test arg1(cu) ≈ arg1(a) - @test arg2(cu) ≈ arg2(a) + @test kroneckerfactors(cu, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(cu, 2) ≈ kroneckerfactors(a, 2) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, s, v = svd_compact(a) usv = u * s * v - @test arg1(usv) ≈ arg1(a) - @test arg2(usv) ≈ arg2(a) + @test kroneckerfactors(usv, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(usv, 2) ≈ kroneckerfactors(a, 2) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt @@ -123,8 +123,8 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, s, v = svd_full(a) usv = u * s * v - @test arg1(usv) ≈ arg1(a) - @test arg2(usv) ≈ arg2(a) + @test kroneckerfactors(usv, 1) ≈ kroneckerfactors(a, 1) + @test kroneckerfactors(usv, 2) ≈ kroneckerfactors(a, 2) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt @@ -137,6 +137,6 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) s = svd_vals(a) s′ = diag(svd_compact(a)[2]) - @test arg1(s) ≈ arg1(s′) - @test arg2(s) ≈ arg2(s′) + @test kroneckerfactors(s, 1) ≈ kroneckerfactors(s′, 1) + @test kroneckerfactors(s, 2) ≈ kroneckerfactors(s′, 2) end diff --git a/test/test_matrixalgebrakit_delta.jl b/test/test_matrixalgebrakit_delta.jl index a693b1e..9191234 100644 --- a/test/test_matrixalgebrakit_delta.jl +++ b/test/test_matrixalgebrakit_delta.jl @@ -1,28 +1,13 @@ using FillArrays: Ones using DiagonalArrays: δ, DeltaMatrix -using KroneckerArrays: ⊗, arguments +using KroneckerArrays: ⊗, kroneckerfactors using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm using MatrixAlgebraKit: - eig_full, - eig_trunc, - eig_vals, - eigh_full, - eigh_trunc, - eigh_vals, - left_null, - left_orth, - left_polar, - lq_compact, - lq_full, - qr_compact, - qr_full, - right_null, - right_orth, - right_polar, - svd_compact, - svd_full, - svd_trunc, - svd_vals + eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc, eigh_vals, + left_null, left_orth, left_polar, + lq_compact, lq_full, qr_compact, qr_full, + right_null, right_orth, right_polar, + svd_compact, svd_full, svd_trunc, svd_vals using Test: @test, @test_broken, @test_throws, @testset using TestExtras: @constinferred @@ -33,60 +18,60 @@ herm(a) = parent(hermitianpart(a)) a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix{complex(elt)} - @test arguments(v, 1) isa DeltaMatrix{complex(elt)} + @test kroneckerfactors(d, 1) isa DeltaMatrix{complex(elt)} + @test kroneckerfactors(v, 1) isa DeltaMatrix{complex(elt)} a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ δ(elt, 3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d - @test arguments(d, 2) isa DeltaMatrix{complex(elt)} - @test arguments(v, 2) isa DeltaMatrix{complex(elt)} + @test kroneckerfactors(d, 2) isa DeltaMatrix{complex(elt)} + @test kroneckerfactors(v, 2) isa DeltaMatrix{complex(elt)} a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix{complex(elt)} - @test arguments(d, 2) isa DeltaMatrix{complex(elt)} - @test arguments(v, 1) isa DeltaMatrix{complex(elt)} - @test arguments(v, 2) isa DeltaMatrix{complex(elt)} + @test kroneckerfactors(d, 1) isa DeltaMatrix{complex(elt)} + @test kroneckerfactors(d, 2) isa DeltaMatrix{complex(elt)} + @test kroneckerfactors(v, 1) isa DeltaMatrix{complex(elt)} + @test kroneckerfactors(v, 2) isa DeltaMatrix{complex(elt)} end for elt in (Float32, ComplexF32) a = δ(elt, 3, 3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) d, v = @constinferred eigh_full($a) @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} + @test kroneckerfactors(d, 1) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 1) isa DeltaMatrix{elt} a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ δ(elt, 3, 3) d, v = @constinferred eigh_full($a) @test a * v ≈ v * d - @test arguments(d, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 2) isa DeltaMatrix{elt} + @test kroneckerfactors(d, 2) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 2) isa DeltaMatrix{elt} a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) d, v = @constinferred eigh_full($a) @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix{real(elt)} - @test arguments(d, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} - @test arguments(v, 2) isa DeltaMatrix{elt} + @test kroneckerfactors(d, 1) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(d, 2) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 1) isa DeltaMatrix{elt} + @test kroneckerfactors(v, 2) isa DeltaMatrix{elt} end for f in (eig_trunc, eigh_trunc) a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3))) d, v = f(a; trunc = (; maxrank = 7)) @test a * v ≈ v * d - @test arguments(d, 1) isa DeltaMatrix - @test arguments(v, 1) isa DeltaMatrix + @test kroneckerfactors(d, 1) isa DeltaMatrix + @test kroneckerfactors(v, 1) isa DeltaMatrix @test size(d) == (6, 6) @test size(v) == (9, 6) a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3) d, v = f(a; trunc = (; maxrank = 7)) @test a * v ≈ v * d - @test arguments(d, 2) isa DeltaMatrix - @test arguments(v, 2) isa DeltaMatrix + @test kroneckerfactors(d, 2) isa DeltaMatrix + @test kroneckerfactors(v, 2) isa DeltaMatrix @test size(d) == (6, 6) @test size(v) == (9, 6) @@ -99,21 +84,21 @@ herm(a) = parent(hermitianpart(a)) d = @constinferred f(a) d′ = f(Matrix(a)) @test sort(Vector(d); by = abs) ≈ sort(d′; by = abs) - @test arguments(d, 1) isa Ones - @test arguments(d, 2) ≈ f(arguments(a, 2)) + @test kroneckerfactors(d, 1) isa Ones + @test kroneckerfactors(d, 2) ≈ f(kroneckerfactors(a, 2)) a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3) d = @constinferred f(a) d′ = f(Matrix(a)) @test sort(Vector(d); by = abs) ≈ sort(d′; by = abs) - @test arguments(d, 2) isa Ones - @test arguments(d, 1) ≈ f(arguments(a, 1)) + @test kroneckerfactors(d, 2) isa Ones + @test kroneckerfactors(d, 1) ≈ f(kroneckerfactors(a, 1)) a = δ(3, 3) ⊗ δ(3, 3) d = @constinferred f(a) @test d == Ones(3) ⊗ Ones(3) - @test arguments(d, 1) isa Ones - @test arguments(d, 2) isa Ones + @test kroneckerfactors(d, 1) isa Ones + @test kroneckerfactors(d, 2) isa Ones end for f in ( @@ -127,22 +112,22 @@ herm(a) = parent(hermitianpart(a)) x, y = f(a) end @test x * y ≈ a - @test arguments(x, 1) isa DeltaMatrix - @test arguments(y, 1) isa DeltaMatrix + @test kroneckerfactors(x, 1) isa DeltaMatrix + @test kroneckerfactors(y, 1) isa DeltaMatrix a = randn(3, 3) ⊗ δ(3, 3) x, y = @constinferred f($a) @test x * y ≈ a - @test arguments(x, 2) isa DeltaMatrix - @test arguments(y, 2) isa DeltaMatrix + @test kroneckerfactors(x, 2) isa DeltaMatrix + @test kroneckerfactors(y, 2) isa DeltaMatrix a = δ(3, 3) ⊗ δ(3, 3) x, y = @constinferred f($a) @test x * y ≈ a - @test arguments(x, 1) isa DeltaMatrix - @test arguments(y, 1) isa DeltaMatrix - @test arguments(x, 2) isa DeltaMatrix - @test arguments(y, 2) isa DeltaMatrix + @test kroneckerfactors(x, 1) isa DeltaMatrix + @test kroneckerfactors(y, 1) isa DeltaMatrix + @test kroneckerfactors(x, 2) isa DeltaMatrix + @test kroneckerfactors(y, 2) isa DeltaMatrix end for f in (svd_compact, svd_full) @@ -153,9 +138,9 @@ herm(a) = parent(hermitianpart(a)) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt - @test arguments(u, 1) isa DeltaMatrix{elt} - @test arguments(s, 1) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} + @test kroneckerfactors(u, 1) isa DeltaMatrix{elt} + @test kroneckerfactors(s, 1) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 1) isa DeltaMatrix{elt} a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) u, s, v = @constinferred f($a) @@ -163,9 +148,9 @@ herm(a) = parent(hermitianpart(a)) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt - @test arguments(u, 2) isa DeltaMatrix{elt} - @test arguments(s, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 2) isa DeltaMatrix{elt} + @test kroneckerfactors(u, 2) isa DeltaMatrix{elt} + @test kroneckerfactors(s, 2) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 2) isa DeltaMatrix{elt} a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) u, s, v = @constinferred f($a) @@ -173,12 +158,12 @@ herm(a) = parent(hermitianpart(a)) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt - @test arguments(u, 1) isa DeltaMatrix{elt} - @test arguments(s, 1) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} - @test arguments(u, 2) isa DeltaMatrix{elt} - @test arguments(s, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 2) isa DeltaMatrix{elt} + @test kroneckerfactors(u, 1) isa DeltaMatrix{elt} + @test kroneckerfactors(s, 1) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 1) isa DeltaMatrix{elt} + @test kroneckerfactors(u, 2) isa DeltaMatrix{elt} + @test kroneckerfactors(s, 2) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 2) isa DeltaMatrix{elt} end end @@ -194,9 +179,9 @@ herm(a) = parent(hermitianpart(a)) @test eltype(v) === elt u′, s′, v′ = svd_trunc(Matrix(a); trunc = (; maxrank = 6)) @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 1) isa DeltaMatrix{elt} - @test arguments(s, 1) isa DeltaMatrix{real(elt)} - @test arguments(v, 1) isa DeltaMatrix{elt} + @test kroneckerfactors(u, 1) isa DeltaMatrix{elt} + @test kroneckerfactors(s, 1) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 1) isa DeltaMatrix{elt} @test size(u) == (9, 6) @test size(s) == (6, 6) @test size(v) == (6, 9) @@ -213,9 +198,9 @@ herm(a) = parent(hermitianpart(a)) @test eltype(v) === elt u′, s′, v′ = svd_trunc(Matrix(a); trunc = (; maxrank = 6)) @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 2) isa DeltaMatrix{elt} - @test arguments(s, 2) isa DeltaMatrix{real(elt)} - @test arguments(v, 2) isa DeltaMatrix{elt} + @test kroneckerfactors(u, 2) isa DeltaMatrix{elt} + @test kroneckerfactors(s, 2) isa DeltaMatrix{real(elt)} + @test kroneckerfactors(v, 2) isa DeltaMatrix{elt} @test size(u) == (9, 6) @test size(s) == (6, 6) @test size(v) == (6, 9) @@ -230,8 +215,8 @@ herm(a) = parent(hermitianpart(a)) d = @constinferred svd_vals(a) d′ = svd_vals(Matrix(a)) @test sort(Vector(d); by = abs) ≈ sort(d′; by = abs) - @test arguments(d, 1) isa Ones{real(elt)} - @test arguments(d, 2) ≈ svd_vals(arguments(a, 2)) + @test kroneckerfactors(d, 1) isa Ones{real(elt)} + @test kroneckerfactors(d, 2) ≈ svd_vals(kroneckerfactors(a, 2)) end for elt in (Float32, ComplexF32) @@ -239,16 +224,16 @@ herm(a) = parent(hermitianpart(a)) d = @constinferred svd_vals(a) d′ = svd_vals(Matrix(a)) @test sort(Vector(d); by = abs) ≈ sort(d′; by = abs) - @test arguments(d, 2) isa Ones{real(elt)} - @test arguments(d, 1) ≈ svd_vals(arguments(a, 1)) + @test kroneckerfactors(d, 2) isa Ones{real(elt)} + @test kroneckerfactors(d, 1) ≈ svd_vals(kroneckerfactors(a, 1)) end for elt in (Float32, ComplexF32) a = δ(elt, 3, 3) ⊗ δ(elt, 3, 3) d = @constinferred svd_vals(a) @test d ≡ Ones{real(elt)}(3) ⊗ Ones{real(elt)}(3) - @test arguments(d, 1) isa Ones{real(elt)} - @test arguments(d, 2) isa Ones{real(elt)} + @test kroneckerfactors(d, 1) isa Ones{real(elt)} + @test kroneckerfactors(d, 2) isa Ones{real(elt)} end # left_null @@ -256,13 +241,13 @@ herm(a) = parent(hermitianpart(a)) @test_broken left_null(a) ## n = @constinferred left_null(a) ## @test norm(n' * a) ≈ 0 - ## @test arguments(n, 1) isa DeltaMatrix + ## @test kroneckerfactors(n, 1) isa DeltaMatrix a = randn(3, 3) ⊗ δ(3, 3) @test_broken left_null(a) ## n = @constinferred left_null(a) ## @test norm(n' * a) ≈ 0 - ## @test arguments(n, 2) isa DeltaMatrix + ## @test kroneckerfactors(n, 2) isa DeltaMatrix a = δ(3, 3) ⊗ δ(3, 3) @test_broken left_null(a) @@ -272,13 +257,13 @@ herm(a) = parent(hermitianpart(a)) @test_broken right_null(a) ## n = @constinferred right_null(a) ## @test norm(a * n') ≈ 0 - ## @test arguments(n, 1) isa DeltaMatrix + ## @test kroneckerfactors(n, 1) isa DeltaMatrix a = randn(3, 3) ⊗ δ(3, 3) @test_broken right_null(a) ## n = @constinferred right_null(a) ## @test norm(a * n') ≈ 0 - ## @test arguments(n, 2) isa DeltaMatrix + ## @test kroneckerfactors(n, 2) isa DeltaMatrix a = δ(3, 3) ⊗ δ(3, 3) @test_broken right_null(a) diff --git a/test/test_tensoralgebra.jl b/test/test_tensoralgebra.jl index 97e02a4..79763bf 100644 --- a/test/test_tensoralgebra.jl +++ b/test/test_tensoralgebra.jl @@ -1,10 +1,10 @@ using TensorAlgebra: matricize, unmatricize -using KroneckerArrays: ⊗, arg1, arg2 +using KroneckerArrays: ⊗, kroneckerfactors using Test: @test, @testset @testset "TensorAlgebraExt" begin a = randn(2, 2, 2) ⊗ randn(3, 3, 3) m = matricize(a, (1, 2), (3,)) - @test m == matricize(arg1(a), (1, 2), (3,)) ⊗ matricize(arg2(a), (1, 2), (3,)) + @test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) ⊗ matricize(kroneckerfactors(a, 2), (1, 2), (3,)) @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a end diff --git a/test/test_tensorproducts.jl b/test/test_tensorproducts.jl index 2812966..d8ed112 100644 --- a/test/test_tensorproducts.jl +++ b/test/test_tensorproducts.jl @@ -1,4 +1,4 @@ -using KroneckerArrays: ×, arg1, arg2, cartesianrange, unproduct +using KroneckerArrays: ×, kroneckerfactors, cartesianrange, unproduct using TensorProducts: tensor_product using Test: @test, @testset @@ -7,7 +7,7 @@ using Test: @test, @testset r2 = cartesianrange(4, 5) r = tensor_product(r1, r2) @test r ≡ cartesianrange(8, 15) - @test arg1(r) ≡ Base.OneTo(8) - @test arg2(r) ≡ Base.OneTo(15) + @test kroneckerfactors(r, 1) ≡ Base.OneTo(8) + @test kroneckerfactors(r, 2) ≡ Base.OneTo(15) @test unproduct(r) ≡ Base.OneTo(120) end From ef6bed993d0581eb5e3c3676bcf615f4b9e607bf Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Nov 2025 20:25:21 -0500 Subject: [PATCH 6/8] Bump v0.3.0 --- Project.toml | 2 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index aa26928..dae7bab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.2.9" +version = "0.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/Project.toml b/docs/Project.toml index 5ce6ea9..b9139c3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" [compat] Documenter = "1" Literate = "2" -KroneckerArrays = "0.2" +KroneckerArrays = "0.3" diff --git a/examples/Project.toml b/examples/Project.toml index ccb8779..22dfcc1 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,4 @@ KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" [compat] -KroneckerArrays = "0.2" +KroneckerArrays = "0.3" diff --git a/test/Project.toml b/test/Project.toml index 6b2615d..b9d9be9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,7 +29,7 @@ DiagonalArrays = "0.3.7" FillArrays = "1" GPUArraysCore = "0.2" JLArrays = "0.2, 0.3" -KroneckerArrays = "0.2" +KroneckerArrays = "0.3" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5" SafeTestsets = "0.1" From cae92d5052389ee80a37fd0c1cadea4abfaa628a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 12:10:45 -0500 Subject: [PATCH 7/8] type stability improvements --- src/matrixalgebrakit.jl | 46 +++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index bbb77fc..12df9c0 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -38,8 +38,12 @@ for f in ( :left_polar, :right_polar, :svd_compact, :svd_full, ) - @eval MAK.copy_input(::typeof($f), a::AbstractKroneckerMatrix) = - ⊗(MAK.copy_input.(($f,), kroneckerfactors(a))...) + @eval function MAK.copy_input(::typeof($f), ab::AbstractKroneckerMatrix) + a, b = kroneckerfactors(ab) + ac = MAK.copy_input($f, a) + bc = MAK.copy_input($f, b) + return ac ⊗ bc + end end for f in ( @@ -65,28 +69,33 @@ for f in ( ) f! = Symbol(f, :!) @eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing - @eval MAK.$f!(a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm) = - otimes.(MAK.$f.(kroneckerfactors(a), kroneckerfactors(alg))...) + @eval function MAK.$f!(ab::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm) + a, b = kroneckerfactors(ab) + algA, algB = kroneckerfactors(alg) + Fa = MAK.$f(a, algA) + Fb = MAK.$f(b, algB) + return Fa .⊗ Fb + end end for f in (:eig_vals, :eigh_vals, :svd_vals) f! = Symbol(f, :!) @eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing - @eval function MAK.$f!(a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm) - d1 = MAK.$f(kroneckerfactors(a, 1), kroneckerfactors(alg, 1)) - d2 = MAK.$f(kroneckerfactors(a, 2), kroneckerfactors(alg, 2)) - return d1 ⊗ d2 + @eval function MAK.$f!(ab::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm) + a, b = kroneckerfactors(ab) + algA, algB = kroneckerfactors(alg) + return MAK.$f(a, algA) ⊗ MAK.$f(b, algB) end end for f in (:left_orth, :right_orth) f! = Symbol(f, :!) - @eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) = - nothing - @eval function MAK.$f!(a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...) - a1 = MAK.$f(kroneckerfactors(a, 1); kwargs..., kwargs1...) - a2 = MAK.$f(kroneckerfactors(a, 2); kwargs..., kwargs2...) - return a1 .⊗ a2 + @eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) = nothing + @eval function MAK.$f!(ab::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...) + a, b = kroneckerfactors(ab) + Fa = MAK.$f(a; kwargs..., kwargs1...) + Fb = MAK.$f(b; kwargs..., kwargs2...) + return Fa .⊗ Fb end end @@ -94,10 +103,11 @@ for f in [:left_null, :right_null] f! = Symbol(f, :!) @eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) = nothing - @eval function MAK.$f!(a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...) - a1 = MAK.$f(kroneckerfactors(a, 1); kwargs..., kwargs1...) - a2 = MAK.$f(kroneckerfactors(a, 2); kwargs..., kwargs2...) - return a1 ⊗ a2 + @eval function MAK.$f!(ab::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...) + a, b = kroneckerfactors(ab) + Na = MAK.$f(a; kwargs..., kwargs1...) + Nb = MAK.$f(b; kwargs..., kwargs2...) + return Na ⊗ Nb end end From 5c04b3869b56988dae10a3669d33d0c44c5c96f3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 13:10:56 -0500 Subject: [PATCH 8/8] change order to make docstrings happy --- src/KroneckerArrays.jl | 6 +++--- src/cartesianproduct.jl | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 5698101..76962ab 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -41,9 +41,9 @@ kroneckerfactortypes(T::Type) = throw(MethodError(kroneckerfactortypes, (T,))) otimes(args...) Construct an object that represents the Kronecker product of the provided `args`. -""" otimes -function otimes(a, b) end -const ⊗ = otimes # unicode alternative +""" (⊗) +function ⊗(a, b) end +const otimes = ⊗ # non-unicode alternative # Includes # -------- diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index ea1d2f0..ec711c2 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -85,11 +85,11 @@ const AnyCartesian = Union{CartesianPair, CartesianProduct, CartesianProductVect Construct an object that represents the Cartesian product of the provided `args`. By default this constructs the singular [`CartesianPair`](@ref) for unknown values, while attempting to promote to more structured types wherever possible. See also [`CartesianProduct`](@ref), [`CartesianProductVector`](@ref) and [`CartesianProductUnitRange`](@ref). -""" times +""" (×) # implement multi-argument version through a left fold -times(x) = x -times(x, y, z...) = foldl(times, (x, y, z...)) -const × = times # unicode alternative +×(x) = x +×(x, y, z...) = foldl(×, (x, y, z...)) +const times = × # non-unicode alternative # fallback definition for cartesian product ×(a, b) = CartesianPair(a, b)