diff --git a/Project.toml b/Project.toml index ed3d9f9..196b1f5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.11" +version = "0.1.12" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index fd8d8e3..597872c 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -1,1326 +1,13 @@ module KroneckerArrays -using GPUArraysCore: GPUArraysCore - export ⊗, × -struct CartesianProduct{A,B} - a::A - b::B -end -arguments(a::CartesianProduct) = (a.a, a.b) -arguments(a::CartesianProduct, n::Int) = arguments(a)[n] - -function Base.show(io::IO, a::CartesianProduct) - print(io, a.a, " × ", a.b) - return nothing -end - -×(a, b) = CartesianProduct(a, b) -Base.length(a::CartesianProduct) = length(a.a) * length(a.b) -Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b] - -function Base.iterate(a::CartesianProduct, state...) - x = iterate(Iterators.product(a.a, a.b), state...) - isnothing(x) && return x - next, new_state = x - return ×(next...), new_state -end - -struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <: - AbstractUnitRange{T} - product::P - range::R -end -Base.first(r::CartesianProductUnitRange) = first(r.range) -Base.last(r::CartesianProductUnitRange) = last(r.range) - -cartesianproduct(r::CartesianProductUnitRange) = getfield(r, :product) -unproduct(r::CartesianProductUnitRange) = getfield(r, :range) - -function CartesianProductUnitRange(p::CartesianProduct) - return CartesianProductUnitRange(p, Base.OneTo(length(p))) -end -function CartesianProductUnitRange(a, b) - return CartesianProductUnitRange(a × b) -end -to_range(a::AbstractUnitRange) = a -to_range(i::Integer) = Base.OneTo(i) -cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b)) -function cartesianrange(p::CartesianProduct) - p′ = to_range(p.a) × to_range(p.b) - return cartesianrange(p′, Base.OneTo(length(p′))) -end -function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) - p′ = to_range(p.a) × to_range(p.b) - return CartesianProductUnitRange(p′, range) -end - -function Base.axes(r::CartesianProductUnitRange) - return (CartesianProductUnitRange(r.product, only(axes(r.range))),) -end - -using Base.Broadcast: DefaultArrayStyle -for f in (:+, :-) - @eval begin - function Broadcast.broadcasted( - ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer - ) - return CartesianProductUnitRange(r.product, $f.(r.range, x)) - end - function Broadcast.broadcasted( - ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange - ) - return CartesianProductUnitRange(r.product, $f.(x, r.range)) - end - end -end - -struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N} - a::A - b::B -end -function KroneckerArray(a::AbstractArray, b::AbstractArray) - if ndims(a) != ndims(b) - throw( - ArgumentError("Kronecker product requires arrays of the same number of dimensions.") - ) - end - elt = promote_type(eltype(a), eltype(b)) - return KroneckerArray(convert(AbstractArray{elt}, a), convert(AbstractArray{elt}, b)) -end -const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B} -const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B} - -using Adapt: Adapt, adapt -Adapt.adapt_structure(to, a::KroneckerArray) = adapt(to, a.a) ⊗ adapt(to, a.b) - -function Base.copy(a::KroneckerArray) - return copy(a.a) ⊗ copy(a.b) -end -function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) - copyto!(dest.a, src.a) - copyto!(dest.b, src.b) - return dest -end - -function Base.similar( - a::AbstractArray, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - return similar(a, elt, map(ax -> ax.product.a, axs)) ⊗ - similar(a, elt, map(ax -> ax.product.b, axs)) -end -function Base.similar( - a::KroneckerArray, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - return similar(a.a, elt, map(ax -> ax.product.a, axs)) ⊗ - similar(a.b, elt, map(ax -> ax.product.b, axs)) -end -function Base.similar( - arrayt::Type{<:AbstractArray}, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - return similar(arrayt, map(ax -> ax.product.a, axs)) ⊗ - similar(arrayt, map(ax -> ax.product.b, axs)) -end -function Base.similar( - arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) where {A,B} - return similar(A, map(ax -> ax.product.a, axs)) ⊗ similar(B, map(ax -> ax.product.b, axs)) -end -function Base.similar( - ::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, sz::Tuple{Int,Vararg{Int}} -) where {A,B} - return similar(promote_type(A, B), sz) -end - -function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}}) - return (t[1]..., flatten(Base.tail(t))...) -end -function flatten(t::Tuple{Tuple}) - return t[1] -end -flatten(::Tuple{}) = () -function interleave(x::Tuple, y::Tuple) - length(x) == length(y) || throw(ArgumentError("Tuples must have the same length.")) - xy = ntuple(i -> (x[i], y[i]), length(x)) - return flatten(xy) -end -# TODO: Maybe use scalar indexing based on KroneckerProducts.jl logic for cartesian indexing: -# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 -function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N} - a′ = reshape(a, interleave(size(a), ntuple(one, N))) - b′ = reshape(b, interleave(ntuple(one, N), size(b))) - c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N))) - sz = ntuple(i -> size(a, i) * size(b, i), N) - return permutedims(reshape(c′, sz), reverse(ntuple(identity, N))) -end -kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b) -kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b) - -Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b) - -function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} - return convert(Array{T,N}, collect(a)) -end - -Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a)) - -function Base.axes(a::KroneckerArray) - return ntuple(ndims(a)) do dim - return CartesianProductUnitRange( - axes(a.a, dim) × axes(a.b, dim), Base.OneTo(size(a, dim)) - ) - end -end - -arguments(a::KroneckerArray) = (a.a, a.b) -arguments(a::KroneckerArray, n::Int) = arguments(a)[n] -argument_types(a::KroneckerArray) = argument_types(typeof(a)) -argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B) - -function Base.print_array(io::IO, a::KroneckerArray) - Base.print_array(io, a.a) - println(io, "\n ⊗") - Base.print_array(io, a.b) - return nothing -end -function Base.show(io::IO, a::KroneckerArray) - show(io, a.a) - print(io, " ⊗ ") - show(io, a.b) - return nothing -end - -⊗(a::AbstractArray, b::AbstractArray) = KroneckerArray(a, b) -⊗(a::Number, b::Number) = a * b -⊗(a::Number, b::AbstractArray) = a * b -⊗(a::AbstractArray, b::Number) = a * b - -function Base.getindex(a::KroneckerArray, i::Integer) - return a[CartesianIndices(a)[i]] -end - -# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing -# in the n-dimensional case and use it to replace the matrix and vector cases: -# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} - return error("Not implemented.") -end - -function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) - GPUArraysCore.assertscalar("getindex") - # Code logic from Kronecker.jl: - # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105 - k, l = size(a.b) - return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1] -end - -function Base.getindex(a::KroneckerVector, i::Integer) - GPUArraysCore.assertscalar("getindex") - k = length(a.b) - return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] -end - -## function Base.getindex(a::KroneckerVector, i::CartesianProduct) -## return a.a[i.a] ⊗ a.b[i.b] -## end -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} - return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...] -end -# Fix ambigiuity error. -Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[] - -function Base.:(==)(a::KroneckerArray, b::KroneckerArray) - return a.a == b.a && a.b == b.b -end -function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...) - return isapprox(a.a, b.a; kwargs...) && isapprox(a.b, b.b; kwargs...) -end -function Base.iszero(a::KroneckerArray) - return iszero(a.a) || iszero(a.b) -end -function Base.isreal(a::KroneckerArray) - return isreal(a.a) && isreal(a.b) -end -function Base.inv(a::KroneckerArray) - return inv(a.a) ⊗ inv(a.b) -end -using LinearAlgebra: LinearAlgebra, pinv -function LinearAlgebra.pinv(a::KroneckerArray; kwargs...) - return pinv(a.a; kwargs...) ⊗ pinv(a.b; kwargs...) -end -function Base.transpose(a::KroneckerArray) - return transpose(a.a) ⊗ transpose(a.b) -end -function Base.adjoint(a::KroneckerArray) - return a.a' ⊗ a.b' -end - -function Base.:*(a::Number, b::KroneckerArray) - return (a * b.a) ⊗ b.b -end -function Base.:*(a::KroneckerArray, b::Number) - return a.a ⊗ (a.b * b) -end -function Base.:/(a::KroneckerArray, b::Number) - return a * inv(b) -end - -function Base.:-(a::KroneckerArray) - return (-a.a) ⊗ a.b -end -for op in (:+, :-) - @eval begin - function Base.$op(a::KroneckerArray, b::KroneckerArray) - if a.b == b.b - return $op(a.a, b.a) ⊗ a.b - elseif a.a == b.a - return a.a ⊗ $op(a.b, b.b) - end - return throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or secord arguments match.", - ), - ) - end - end -end - -using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted -struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end -function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N} - return KroneckerStyle{N,a,b}() -end -function KroneckerStyle(a::AbstractArrayStyle{N}, b::AbstractArrayStyle{N}) where {N} - return KroneckerStyle{N}(a, b) -end -function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M} - return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}() -end -function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B} - return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B)) -end -function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N} - return KroneckerStyle{N}( - BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b) - ) -end -function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B} - ax_a = map(ax -> ax.product.a, axes(bc)) - ax_b = map(ax -> ax.product.b, axes(bc)) - bc_a = Broadcasted(A, nothing, (), ax_a) - bc_b = Broadcasted(B, nothing, (), ax_b) - a = similar(bc_a, elt) - b = similar(bc_b, elt) - return a ⊗ b -end -function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle}) - return throw( - ArgumentError( - "Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", - ), - ) -end - -function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) - return throw( - ArgumentError( - "Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", - ), - ) -end -function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...) - return throw( - ArgumentError( - "Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", - ), - ) -end -function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray) - dest.a .= a.a - dest.b .= a.b - return dest -end -for f in [:+, :-] - @eval begin - function Base.map!( - ::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray - ) - if a.b == b.b - map!($f, dest.a, a.a, b.a) - dest.b .= a.b - elseif a.a == b.a - dest.a .= a.a - map!($f, dest.b, a.b, b.b) - else - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), - ) - end - return dest - end - end -end -function Base.map!( - f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray -) - dest.a .= f.f.(f.x, a.a) - dest.b .= a.b - return dest -end -function Base.map!( - f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray -) - dest.a .= a.a - dest.b .= f.f.(a.b, f.x) - return dest -end -function Base.map!( - f::Base.Fix2{typeof(/),<:Number}, dest::KroneckerArray, a::KroneckerArray -) - return map!(Base.Fix2(*, inv(f.x)), dest, a) -end -function Base.map!(::typeof(conj), dest::KroneckerArray, a::KroneckerArray) - dest.a .= conj.(a.a) - dest.b .= conj.(a.b) - return dest -end - -using LinearAlgebra: - LinearAlgebra, - Diagonal, - Eigen, - SVD, - det, - diag, - eigen, - eigvals, - lq, - mul!, - norm, - qr, - svd, - svdvals, - tr - -using DiagonalArrays: DiagonalArrays, diagonal -function DiagonalArrays.diagonal(a::KroneckerArray) - return diagonal(a.a) ⊗ diagonal(a.b) -end - -function Base.:*(a::KroneckerArray, b::KroneckerArray) - return (a.a * b.a) ⊗ (a.b * b.b) -end -function LinearAlgebra.mul!( - c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number -) - iszero(β) || - iszero(c) || - throw( - ArgumentError( - "Can't multiple KroneckerArrays with nonzero β and nonzero destination." - ), - ) - mul!(c.a, a.a, b.a) - mul!(c.b, a.b, b.b, α, β) - return c -end -function LinearAlgebra.tr(a::KroneckerArray) - return tr(a.a) ⊗ tr(a.b) -end -function LinearAlgebra.norm(a::KroneckerArray, p::Int=2) - return norm(a.a, p) ⊗ norm(a.b, p) -end - -function Base.real(a::KroneckerArray) - if iszero(imag(a.a)) || iszero(imag(a.b)) - return real(a.a) ⊗ real(a.b) - elseif iszero(real(a.a)) || iszero(real(a.b)) - return -imag(a.a) ⊗ imag(a.b) - end - return real(a.a) ⊗ real(a.b) - imag(a.a) ⊗ imag(a.b) -end -function Base.imag(a::KroneckerArray) - if iszero(imag(a.a)) || iszero(real(a.b)) - return real(a.a) ⊗ imag(a.b) - elseif iszero(real(a.a)) || iszero(imag(a.b)) - return imag(a.a) ⊗ real(a.b) - end - return real(a.a) ⊗ imag(a.b) + imag(a.a) ⊗ real(a.b) -end - -using MatrixAlgebraKit: MatrixAlgebraKit, diagview -function MatrixAlgebraKit.diagview(a::KroneckerMatrix) - return diagview(a.a) ⊗ diagview(a.b) -end -function LinearAlgebra.diag(a::KroneckerArray) - return copy(diagview(a.a)) ⊗ copy(diagview(a.b)) -end - -# Matrix functions -const MATRIX_FUNCTIONS = [ - :exp, - :cis, - :log, - :sqrt, - :cbrt, - :cos, - :sin, - :tan, - :csc, - :sec, - :cot, - :cosh, - :sinh, - :tanh, - :csch, - :sech, - :coth, - :acos, - :asin, - :atan, - :acsc, - :asec, - :acot, - :acosh, - :asinh, - :atanh, - :acsch, - :asech, - :acoth, -] - -for f in MATRIX_FUNCTIONS - @eval begin - function Base.$f(a::KroneckerArray) - return throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported.")) - end - end -end - -using LinearAlgebra: checksquare -function LinearAlgebra.det(a::KroneckerArray) - checksquare(a.a) - checksquare(a.b) - return det(a.a) ^ size(a.b, 1) * det(a.b) ^ size(a.a, 1) -end - -function LinearAlgebra.svd(a::KroneckerArray) - Fa = svd(a.a) - Fb = svd(a.b) - return SVD(Fa.U ⊗ Fb.U, Fa.S ⊗ Fb.S, Fa.Vt ⊗ Fb.Vt) -end -function LinearAlgebra.svdvals(a::KroneckerArray) - return svdvals(a.a) ⊗ svdvals(a.b) -end -function LinearAlgebra.eigen(a::KroneckerArray) - Fa = eigen(a.a) - Fb = eigen(a.b) - return Eigen(Fa.values ⊗ Fb.values, Fa.vectors ⊗ Fb.vectors) -end -function LinearAlgebra.eigvals(a::KroneckerArray) - return eigvals(a.a) ⊗ eigvals(a.b) -end - -struct KroneckerQ{A,B} - a::A - b::B -end -function Base.:*(a::KroneckerQ, b::KroneckerQ) - return (a.a * b.a) ⊗ (a.b * b.b) -end -function Base.:*(a::KroneckerQ, b::KroneckerArray) - return (a.a * b.a) ⊗ (a.b * b.b) -end -function Base.:*(a::KroneckerArray, b::KroneckerQ) - return (a.a * b.a) ⊗ (a.b * b.b) -end -function Base.adjoint(a::KroneckerQ) - return KroneckerQ(a.a', a.b') -end - -struct KroneckerQR{QQ,RR} - Q::QQ - R::RR -end -Base.iterate(F::KroneckerQR) = (F.Q, Val(:R)) -Base.iterate(F::KroneckerQR, ::Val{:R}) = (F.R, Val(:done)) -Base.iterate(F::KroneckerQR, ::Val{:done}) = nothing -function ⊗(a::LinearAlgebra.QRCompactWYQ, b::LinearAlgebra.QRCompactWYQ) - return KroneckerQ(a, b) -end -function LinearAlgebra.qr(a::KroneckerArray) - Fa = qr(a.a) - Fb = qr(a.b) - return KroneckerQR(Fa.Q ⊗ Fb.Q, Fa.R ⊗ Fb.R) -end - -struct KroneckerLQ{LL,QQ} - L::LL - Q::QQ -end -Base.iterate(F::KroneckerLQ) = (F.L, Val(:Q)) -Base.iterate(F::KroneckerLQ, ::Val{:Q}) = (F.Q, Val(:done)) -Base.iterate(F::KroneckerLQ, ::Val{:done}) = nothing -function ⊗(a::LinearAlgebra.LQPackedQ, b::LinearAlgebra.LQPackedQ) - return KroneckerQ(a, b) -end -function LinearAlgebra.lq(a::KroneckerArray) - Fa = lq(a.a) - Fb = lq(a.b) - return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q) -end - -using DerivableInterfaces: DerivableInterfaces, zero! -function DerivableInterfaces.zero!(a::KroneckerArray) - zero!(a.a) - zero!(a.b) - return a -end - -using FillArrays: Eye -const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B} -const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} -const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} - -using DerivableInterfaces: DerivableInterfaces, zero! -function DerivableInterfaces.zero!(a::EyeKronecker) - zero!(a.b) - return a -end -function DerivableInterfaces.zero!(a::KroneckerEye) - zero!(a.a) - return a -end -function DerivableInterfaces.zero!(a::EyeEye) - return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`.")) -end - -function Base.:*(a::Number, b::EyeKronecker) - return b.a ⊗ (a * b.b) -end -function Base.:*(a::Number, b::KroneckerEye) - return (a * b.a) ⊗ b.b -end -function Base.:*(a::Number, b::EyeEye) - return (a * b.a) ⊗ b.b -end -function Base.:*(a::EyeKronecker, b::Number) - return a.a ⊗ (a.b * b) -end -function Base.:*(a::KroneckerEye, b::Number) - return (a.a * b) ⊗ a.b -end -function Base.:*(a::EyeEye, b::Number) - return a.a ⊗ (a.b * b) -end - -function Base.:-(a::EyeKronecker) - return a.a ⊗ (-a.b) -end -function Base.:-(a::KroneckerEye) - return (-a.a) ⊗ a.b -end -function Base.:-(a::EyeEye) - return (-a.a) ⊗ a.b -end -for op in (:+, :-) - @eval begin - function Base.$op(a::EyeKronecker, b::EyeKronecker) - if a.a ≠ b.a - return throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or secord arguments match.", - ), - ) - end - return a.a ⊗ $op(a.b, b.b) - end - function Base.$op(a::KroneckerEye, b::KroneckerEye) - if a.b ≠ b.b - return throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or secord arguments match.", - ), - ) - end - return $op(a.a, b.a) ⊗ a.b - end - function Base.$op(a::EyeEye, b::EyeEye) - if a.b ≠ b.b - return throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or secord arguments match.", - ), - ) - end - return $op(a.a, b.a) ⊗ a.b - end - end -end - -function Base.map!(::typeof(identity), dest::EyeKronecker, a::EyeKronecker) - dest.b .= a.b - return dest -end -function Base.map!(::typeof(identity), dest::KroneckerEye, a::KroneckerEye) - dest.a .= a.a - return dest -end -function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye) - return error("Can't write in-place.") -end -for f in [:+, :-] - @eval begin - function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker) - if dest.a ≠ a.a ≠ b.a - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), - ) - end - map!($f, dest.b, a.b, b.b) - return dest - end - function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye) - if dest.b ≠ a.b ≠ b.b - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), - ) - end - map!($f, dest.a, a.a, b.a) - return dest - end - function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye) - return error("Can't write in-place.") - end - end -end -function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker) - map!(f, dest.b, a.b) - return dest -end -function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye) - map!(f, dest.a, a.a) - return dest -end -function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye) - return error("Can't write in-place.") -end -function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) - dest.b .= f.f.(f.x, a.b) - return dest -end -function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) - dest.a .= f.f.(f.x, a.a) - return dest -end -function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) - return error("Can't write in-place.") -end -function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) - dest.b .= f.f.(a.b, f.x) - return dest -end -function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) - dest.a .= f.f.(a.a, f.x) - return dest -end -function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) - return error("Can't write in-place.") -end - -using MatrixAlgebraKit: - MatrixAlgebraKit, - AbstractAlgorithm, - TruncationStrategy, - default_eig_algorithm, - default_eigh_algorithm, - default_lq_algorithm, - default_polar_algorithm, - default_qr_algorithm, - default_svd_algorithm, - eig_full!, - eig_trunc!, - eig_vals!, - eigh_full!, - eigh_trunc!, - eigh_vals!, - initialize_output, - left_null!, - left_orth!, - left_polar!, - lq_compact!, - lq_full!, - qr_compact!, - qr_full!, - right_null!, - right_orth!, - right_polar!, - svd_compact!, - svd_full!, - svd_trunc!, - svd_vals!, - truncate! - -struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm - a::A - b::B -end - -using MatrixAlgebraKit: - copy_input, - eig_full, - eig_vals, - eigh_full, - eigh_vals, - qr_compact, - qr_full, - left_null, - left_orth, - left_polar, - lq_compact, - lq_full, - right_null, - right_orth, - right_polar, - svd_compact, - svd_full - -for f in [ - :eig_full, - :eigh_full, - :qr_compact, - :qr_full, - :left_polar, - :lq_compact, - :lq_full, - :right_polar, - :svd_compact, - :svd_full, -] - @eval begin - function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix) - return copy_input($f, a.a) ⊗ copy_input($f, a.b) - end - end -end - -for f in [ - :default_eig_algorithm, - :default_eigh_algorithm, - :default_lq_algorithm, - :default_qr_algorithm, - :default_polar_algorithm, - :default_svd_algorithm, -] - @eval begin - function MatrixAlgebraKit.$f( - A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs... - ) - A1, A2 = argument_types(A) - return KroneckerAlgorithm( - $f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...) - ) - end - end -end - -# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. -function MatrixAlgebraKit.default_algorithm( - ::typeof(qr_compact!), A::Type{<:KroneckerMatrix}; kwargs... -) - return default_qr_algorithm(A; kwargs...) -end -# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. -function MatrixAlgebraKit.default_algorithm( - ::typeof(qr_full!), A::Type{<:KroneckerMatrix}; kwargs... -) - return default_qr_algorithm(A; kwargs...) -end - -for f in [ - :eig_full!, - :eigh_full!, - :qr_compact!, - :qr_full!, - :left_polar!, - :lq_compact!, - :lq_full!, - :right_polar!, - :svd_compact!, - :svd_full!, -] - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm - ) - return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b) - end - function MatrixAlgebraKit.$f( - a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - $f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...) - $f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...) - return F - end - end -end - -for f in [:eig_vals!, :eigh_vals!, :svd_vals!] - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm - ) - return initialize_output($f, a.a, alg.a) ⊗ initialize_output($f, a.b, alg.b) - end - function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) - $f(a.a, F.a, alg.a) - $f(a.b, F.b, alg.b) - return F - end - end -end - -for f in [:left_orth!, :right_orth!] - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) - return initialize_output($f, a.a) .⊗ initialize_output($f, a.b) - end - end -end - -for f in [:left_null!, :right_null!] - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) - return initialize_output($f, a.a) ⊗ initialize_output($f, a.b) - end - function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...) - $f(a.a, F.a; kwargs..., kwargs1...) - $f(a.b, F.b; kwargs..., kwargs2...) - return F - end - end -end - -#################################################################################### -# Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and -# `A ⊗ Eye(n)` where `A`. -# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34 -# is merged. - -using FillArrays: SquareEye -const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B} -const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} -const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} - -using Adapt: Adapt, adapt -Adapt.adapt_structure(to, a::SquareEyeKronecker) = a.a ⊗ adapt(to, a.b) -Adapt.adapt_structure(to, a::KroneckerSquareEye) = adapt(to, a.a) ⊗ a.b -Adapt.adapt_structure(to, a::SquareEyeSquareEye) = adapt(to, a.a) ⊗ a.b - -# Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`. -function Base.similar( - a::SquareEyeKronecker, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - ax_a = map(ax -> ax.product.a, axs) - ax_b = map(ax -> ax.product.b, axs) - eye_ax_a = (only(unique(ax_a)),) - return Eye{elt}(eye_ax_a) ⊗ similar(a.b, elt, ax_b) -end -function Base.similar( - a::KroneckerSquareEye, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - ax_a = map(ax -> ax.product.a, axs) - ax_b = map(ax -> ax.product.b, axs) - eye_ax_b = (only(unique(ax_b)),) - return similar(a.a, elt, ax_a) ⊗ Eye{elt}(eye_ax_b) -end -function Base.similar( - a::SquareEyeSquareEye, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, -) - ax_a = map(ax -> ax.product.a, axs) - ax_b = map(ax -> ax.product.b, axs) - eye_ax_a = (only(unique(ax_a)),) - eye_ax_b = (only(unique(ax_b)),) - return Eye{elt}(eye_ax_a) ⊗ Eye{elt}(eye_ax_b) -end - -function Base.similar( - arrayt::Type{<:SquareEyeKronecker{T,A,B}}, - axs::NTuple{2,CartesianProductUnitRange{<:Integer}}, -) where {T,A<:SquareEye{T},B} - ax_a = map(ax -> ax.product.a, axs) - ax_b = map(ax -> ax.product.b, axs) - eye_ax_a = (only(unique(ax_a)),) - return Eye{T}(eye_ax_a) ⊗ similar(B, ax_b) -end -function Base.similar( - arrayt::Type{<:KroneckerSquareEye{T,A,B}}, - axs::NTuple{2,CartesianProductUnitRange{<:Integer}}, -) where {T,A,B<:SquareEye{T}} - ax_a = map(ax -> ax.product.a, axs) - ax_b = map(ax -> ax.product.b, axs) - eye_ax_b = (only(unique(ax_b)),) - return similar(A, ax_a) ⊗ Eye{T}(eye_ax_b) -end -function Base.similar( - arrayt::Type{<:SquareEyeSquareEye}, axs::NTuple{2,CartesianProductUnitRange{<:Integer}} -) - elt = eltype(arrayt) - ax_a = map(ax -> ax.product.a, axs) - ax_b = map(ax -> ax.product.b, axs) - eye_ax_a = (only(unique(ax_a)),) - eye_ax_b = (only(unique(ax_b)),) - return Eye{elt}(eye_ax_a) ⊗ Eye{elt}(eye_ax_b) -end - -struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm - kwargs::KWargs -end -SquareEyeAlgorithm(; kwargs...) = SquareEyeAlgorithm((; kwargs...)) - -# Defined to avoid type piracy. -_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a) -_copy_input_squareeye(f::F, a::SquareEye) where {F} = a - -for f in [ - :eig_full, - :eig_vals, - :eigh_full, - :eigh_vals, - :qr_compact, - :qr_full, - :left_null, - :left_orth, - :left_polar, - :lq_compact, - :lq_full, - :right_null, - :right_orth, - :right_polar, - :svd_compact, - :svd_full, -] - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.copy_input(::typeof($f), a::$T) - return _copy_input_squareeye($f, a.a) ⊗ _copy_input_squareeye($f, a.b) - end - end - end -end - -for f in [ - :default_eig_algorithm, - :default_eigh_algorithm, - :default_lq_algorithm, - :default_qr_algorithm, - :default_polar_algorithm, - :default_svd_algorithm, -] - f′ = Symbol("_", f, "_squareeye") - @eval begin - $f′(a; kwargs...) = $f(a; kwargs...) - $f′(a::Type{<:SquareEye}; kwargs...) = SquareEyeAlgorithm(; kwargs...) - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.$f(A::Type{<:$T}; kwargs1=(;), kwargs2=(;), kwargs...) - A1, A2 = argument_types(A) - return KroneckerAlgorithm( - $f′(A1; kwargs..., kwargs1...), $f′(A2; kwargs..., kwargs2...) - ) - end - end - end -end - -# Defined to avoid type piracy. -_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a) -_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg) - -for f in [:left_null!, :right_null!] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = a - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = a - end -end -for f in [ - :qr_compact!, - :qr_full!, - :left_orth!, - :left_polar!, - :lq_compact!, - :lq_full!, - :right_orth!, - :right_polar!, -] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a) - end -end -_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye) = complex.((a, a)) -_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye, alg) = complex.((a, a)) -_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye) = (real(a), a) -_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye, alg) = (real(a), a) -for f in [:svd_compact!, :svd_full!] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, real(a), a) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, real(a), a) - end -end - -for f in [ - :eig_full!, - :eigh_full!, - :qr_compact!, - :qr_full!, - :left_orth!, - :left_polar!, - :lq_compact!, - :lq_full!, - :right_orth!, - :right_polar!, - :svd_compact!, - :svd_full!, -] - f′ = Symbol("_", f, "_squareeye") - @eval begin - $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) - $f′(a, F, alg::SquareEyeAlgorithm) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) - return _initialize_output_squareeye($f, a.a) .⊗ - _initialize_output_squareeye($f, a.b) - end - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::$T, alg::KroneckerAlgorithm - ) - return _initialize_output_squareeye($f, a.a, alg.a) .⊗ - _initialize_output_squareeye($f, a.b, alg.b) - end - function MatrixAlgebraKit.$f( - a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - $f′(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...) - $f′(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...) - return F - end - end - end -end - -for f in [:left_null!, :right_null!] - f′ = Symbol("_", f, "_squareeye") - @eval begin - $f′(a, F; kwargs...) = $f(a, F; kwargs...) - $f′(a::SquareEye, F) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) - return _initialize_output_squareeye($f, a.a) ⊗ _initialize_output_squareeye($f, a.b) - end - function MatrixAlgebraKit.$f(a::$T, F; kwargs1=(;), kwargs2=(;), kwargs...) - $f′(a.a, F.a; kwargs..., kwargs1...) - $f′(a.b, F.b; kwargs..., kwargs2...) - return F - end - end - end -end - -function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::SquareEyeSquareEye) - return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) -end -function MatrixAlgebraKit.left_null!( - a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... -) - return throw(MethodError(left_null!, (a, F))) -end - -function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::SquareEyeSquareEye) - return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) -end -function MatrixAlgebraKit.right_null!( - a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... -) - return throw(MethodError(right_null!, (a, F))) -end - -_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye) = parent(a) -_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye, alg) = parent(a) -for f in [:eigh_vals!, svd_vals!] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = real(parent(a)) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = real(parent(a)) - end -end - -for f in [:eig_vals!, :eigh_vals!, :svd_vals!] - f′ = Symbol("_", f, "_squareeye") - @eval begin - $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) - $f′(a, F, alg::SquareEyeAlgorithm) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::$T, alg::KroneckerAlgorithm - ) - return _initialize_output_squareeye($f, a.a, alg.a) ⊗ - _initialize_output_squareeye($f, a.b, alg.b) - end - function MatrixAlgebraKit.$f( - a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - $f′(a.a, F.a, alg.a; kwargs..., kwargs1...) - $f′(a.b, F.b, alg.b; kwargs..., kwargs2...) - return F - end - end - end -end - -using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate! - -struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy - strategy::T -end - -# Avoid instantiating the identity. -function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2}) - return a.a ⊗ a.b[I[1].b, I[2].b] -end -function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2}) - return a.a[I[1].a, I[2].a] ⊗ a.b -end -function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2}) - return a -end - -using FillArrays: OnesVector -const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} -const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} -const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} - -function MatrixAlgebraKit.findtruncated( - values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.a, prods)) - # Drop truncations that occur within the identity. - I_data = filter(I_data) do i - return count(x -> x.a == i, prods) == length(values.a) - end - return (:) × I_data -end -function MatrixAlgebraKit.findtruncated( - values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.b, prods)) - # Drop truncations that occur within the identity. - I_data = filter(I_data) do i - return count(x -> x.b == i, prods) == length(values.b) - end - return I_data × (:) -end -function MatrixAlgebraKit.findtruncated( - values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy -) - return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) -end - -for f in [:eig_trunc!, :eigh_trunc!] - @eval begin - function MatrixAlgebraKit.truncate!( - ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy - ) - return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) - end - function MatrixAlgebraKit.truncate!( - ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy - ) - I = findtruncated(diagview(D), strategy) - return (D[I, I], V[(:) × (:), I]) - end - end -end - -function MatrixAlgebraKit.truncate!( - f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy -) - return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) -end -function MatrixAlgebraKit.truncate!( - ::typeof(svd_trunc!), - (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, - strategy::KroneckerTruncationStrategy, -) - I = findtruncated(diagview(S), strategy) - return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) -end - -for f in MATRIX_FUNCTIONS - @eval begin - function Base.$f(a::SquareEyeKronecker) - return a.a ⊗ $f(a.b) - end - function Base.$f(a::KroneckerSquareEye) - return $f(a.a) ⊗ a.b - end - function Base.$f(a::SquareEyeSquareEye) - return throw(ArgumentError("`$($f)` on `Eye ⊗ Eye` is not supported.")) - end - end -end - -function LinearAlgebra.pinv(a::SquareEyeKronecker; kwargs...) - return a.a ⊗ pinv(a.b; kwargs...) -end -function LinearAlgebra.pinv(a::KroneckerSquareEye; kwargs...) - return pinv(a.a; kwargs...) ⊗ a.b -end -function LinearAlgebra.pinv(a::SquareEyeSquareEye; kwargs...) - return a -end +include("cartesianproduct.jl") +include("kroneckerarray.jl") +include("linearalgebra.jl") +include("matrixalgebrakit.jl") +include("fillarrays/kroneckerarray.jl") +include("fillarrays/linearalgebra.jl") +include("fillarrays/matrixalgebrakit.jl") end diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl new file mode 100644 index 0000000..021c7c0 --- /dev/null +++ b/src/cartesianproduct.jl @@ -0,0 +1,71 @@ +struct CartesianProduct{A,B} + a::A + b::B +end +arguments(a::CartesianProduct) = (a.a, a.b) +arguments(a::CartesianProduct, n::Int) = arguments(a)[n] + +function Base.show(io::IO, a::CartesianProduct) + print(io, a.a, " × ", a.b) + return nothing +end + +×(a, b) = CartesianProduct(a, b) +Base.length(a::CartesianProduct) = length(a.a) * length(a.b) +Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b] + +function Base.iterate(a::CartesianProduct, state...) + x = iterate(Iterators.product(a.a, a.b), state...) + isnothing(x) && return x + next, new_state = x + return ×(next...), new_state +end + +struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <: + AbstractUnitRange{T} + product::P + range::R +end +Base.first(r::CartesianProductUnitRange) = first(r.range) +Base.last(r::CartesianProductUnitRange) = last(r.range) + +cartesianproduct(r::CartesianProductUnitRange) = getfield(r, :product) +unproduct(r::CartesianProductUnitRange) = getfield(r, :range) + +function CartesianProductUnitRange(p::CartesianProduct) + return CartesianProductUnitRange(p, Base.OneTo(length(p))) +end +function CartesianProductUnitRange(a, b) + return CartesianProductUnitRange(a × b) +end +to_range(a::AbstractUnitRange) = a +to_range(i::Integer) = Base.OneTo(i) +cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b)) +function cartesianrange(p::CartesianProduct) + p′ = to_range(p.a) × to_range(p.b) + return cartesianrange(p′, Base.OneTo(length(p′))) +end +function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) + p′ = to_range(p.a) × to_range(p.b) + return CartesianProductUnitRange(p′, range) +end + +function Base.axes(r::CartesianProductUnitRange) + return (CartesianProductUnitRange(r.product, only(axes(r.range))),) +end + +using Base.Broadcast: DefaultArrayStyle +for f in (:+, :-) + @eval begin + function Broadcast.broadcasted( + ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer + ) + return CartesianProductUnitRange(r.product, $f.(r.range, x)) + end + function Broadcast.broadcasted( + ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange + ) + return CartesianProductUnitRange(r.product, $f.(x, r.range)) + end + end +end diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl new file mode 100644 index 0000000..1030602 --- /dev/null +++ b/src/fillarrays/kroneckerarray.jl @@ -0,0 +1,182 @@ +using FillArrays: Eye +const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B} +const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} +const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} + +using FillArrays: SquareEye +const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B} +const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} +const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} + +# Like `adapt` but preserves `Eye`. +_adapt(to, a::Eye) = a + +# Like `similar` but preserves `Eye`. +function _similar(a::Eye, elt::Type, axs::NTuple{2,AbstractUnitRange}) + return Eye{elt}(axs) +end +function _similar(arrayt::Type{<:Eye}, axs::NTuple{2,AbstractUnitRange}) + return Eye{eltype(arrayt)}(axs) +end + +# Like `similar` but preserves `SquareEye`. +function _similar(a::SquareEye, elt::Type, axs::NTuple{2,AbstractUnitRange}) + return Eye{elt}((only(unique(axs)),)) +end +function _similar(arrayt::Type{<:SquareEye}, axs::NTuple{2,AbstractUnitRange}) + return Eye{eltype(arrayt)}((only(unique(axs)),)) +end + +# Like `copy` but preserves `Eye`. +_copy(a::Eye) = a + +using DerivableInterfaces: DerivableInterfaces, zero! +function DerivableInterfaces.zero!(a::EyeKronecker) + zero!(a.b) + return a +end +function DerivableInterfaces.zero!(a::KroneckerEye) + zero!(a.a) + return a +end +function DerivableInterfaces.zero!(a::EyeEye) + return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`.")) +end + +function Base.:*(a::Number, b::EyeKronecker) + return b.a ⊗ (a * b.b) +end +function Base.:*(a::Number, b::KroneckerEye) + return (a * b.a) ⊗ b.b +end +function Base.:*(a::Number, b::EyeEye) + return error("Can't multiply `Eye ⊗ Eye` by a number.") +end +function Base.:*(a::EyeKronecker, b::Number) + return a.a ⊗ (a.b * b) +end +function Base.:*(a::KroneckerEye, b::Number) + return (a.a * b) ⊗ a.b +end +function Base.:*(a::EyeEye, b::Number) + return error("Can't multiply `Eye ⊗ Eye` by a number.") +end + +function Base.:-(a::EyeKronecker) + return a.a ⊗ (-a.b) +end +function Base.:-(a::KroneckerEye) + return (-a.a) ⊗ a.b +end +function Base.:-(a::EyeEye) + return error("Can't multiply `Eye ⊗ Eye` by a number.") +end + +for op in (:+, :-) + @eval begin + function Base.$op(a::EyeKronecker, b::EyeKronecker) + if a.a ≠ b.a + return throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or secord arguments match.", + ), + ) + end + return a.a ⊗ $op(a.b, b.b) + end + function Base.$op(a::KroneckerEye, b::KroneckerEye) + if a.b ≠ b.b + return throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or secord arguments match.", + ), + ) + end + return $op(a.a, b.a) ⊗ a.b + end + function Base.$op(a::EyeEye, b::EyeEye) + if a.b ≠ b.b + return throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or secord arguments match.", + ), + ) + end + return $op(a.a, b.a) ⊗ a.b + end + end +end + +function Base.map!(f::typeof(identity), dest::EyeKronecker, a::EyeKronecker) + map!(f, dest.b, src.b) + return dest +end +function Base.map!(f::typeof(identity), dest::KroneckerEye, a::KroneckerEye) + map!(f, dest.a, src.a) + return dest +end +function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye) + return error("Can't write in-place.") +end +for f in [:+, :-] + @eval begin + function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker) + if dest.a ≠ a.a ≠ b.a + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or second arguments match.", + ), + ) + end + map!($f, dest.b, a.b, b.b) + return dest + end + function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye) + if dest.b ≠ a.b ≠ b.b + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or second arguments match.", + ), + ) + end + map!($f, dest.a, a.a, b.a) + return dest + end + function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye) + return error("Can't write in-place.") + end + end +end +function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker) + map!(f, dest.b, a.b) + return dest +end +function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye) + map!(f, dest.a, a.a) + return dest +end +function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye) + return error("Can't write in-place.") +end +function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) + map!(f, dest.b, a.b) + return dest +end +function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) + map!(f, dest.a, a.a) + return dest +end +function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) + return error("Can't write in-place.") +end +function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) + map!(f, dest.b, a.b) + return dest +end +function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) + map!(f, dest.a, a.a) + return dest +end +function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) + return error("Can't write in-place.") +end diff --git a/src/fillarrays/linearalgebra.jl b/src/fillarrays/linearalgebra.jl new file mode 100644 index 0000000..ed2cd6a --- /dev/null +++ b/src/fillarrays/linearalgebra.jl @@ -0,0 +1,75 @@ +using FillArrays: Eye, SquareEye +using LinearAlgebra: LinearAlgebra, mul!, pinv + +function check_mul_axes(a::AbstractMatrix, b::AbstractMatrix) + return axes(a, 2) == axes(b, 1) || throw(DimensionMismatch("Incompatible matrix sizes.")) +end + +function _mul(a::Eye, b::Eye) + check_mul_axes(a, b) + T = promote_type(eltype(a), eltype(b)) + return Eye{T}((axes(a, 1), axes(b, 2))) +end +function _mul(a::SquareEye, b::SquareEye) + check_mul_axes(a, b) + return Diagonal(diagview(a) .* diagview(b)) +end + +for f in MATRIX_FUNCTIONS + @eval begin + function Base.$f(a::EyeKronecker) + LinearAlgebra.checksquare(a.a) + return a.a ⊗ $f(a.b) + end + function Base.$f(a::KroneckerEye) + LinearAlgebra.checksquare(a.b) + return $f(a.a) ⊗ a.b + end + function Base.$f(a::EyeEye) + LinearAlgebra.checksquare(a) + return throw(ArgumentError("`$($f)` on `Eye ⊗ Eye` is not supported.")) + end + end +end + +function LinearAlgebra.mul!( + c::EyeKronecker, a::EyeKronecker, b::EyeKronecker, α::Number, β::Number +) + iszero(β) || + iszero(c) || + throw( + ArgumentError( + "Can't multiple KroneckerArrays with nonzero β and nonzero destination." + ), + ) + check_mul_axes(a.a, b.a) + mul!(c.b, a.b, b.b, α, β) + return c +end +function LinearAlgebra.mul!( + c::KroneckerEye, a::KroneckerEye, b::KroneckerEye, α::Number, β::Number +) + iszero(β) || + iszero(c) || + throw( + ArgumentError( + "Can't multiple KroneckerArrays with nonzero β and nonzero destination." + ), + ) + check_mul_axes(a.b, b.b) + mul!(c.a, a.a, b.a, α, β) + return c +end +function LinearAlgebra.mul!(c::EyeEye, a::EyeEye, b::EyeEye, α::Number, β::Number) + return throw(ArgumentError("Can't multiple `Eye ⊗ Eye` in-place.")) +end + +function LinearAlgebra.pinv(a::EyeKronecker; kwargs...) + return a.a ⊗ pinv(a.b; kwargs...) +end +function LinearAlgebra.pinv(a::KroneckerEye; kwargs...) + return pinv(a.a; kwargs...) ⊗ a.b +end +function LinearAlgebra.pinv(a::EyeEye; kwargs...) + return a +end diff --git a/src/fillarrays/matrixalgebrakit.jl b/src/fillarrays/matrixalgebrakit.jl new file mode 100644 index 0000000..1f82d24 --- /dev/null +++ b/src/fillarrays/matrixalgebrakit.jl @@ -0,0 +1,298 @@ +#################################################################################### +# Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and +# `A ⊗ Eye(n)` where `A`. +# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34 +# is merged. + +struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm + kwargs::KWargs +end +SquareEyeAlgorithm(; kwargs...) = SquareEyeAlgorithm((; kwargs...)) + +# Defined to avoid type piracy. +_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a) +_copy_input_squareeye(f::F, a::SquareEye) where {F} = a + +for f in [ + :eig_full, + :eig_vals, + :eigh_full, + :eigh_vals, + :qr_compact, + :qr_full, + :left_null, + :left_orth, + :left_polar, + :lq_compact, + :lq_full, + :right_null, + :right_orth, + :right_polar, + :svd_compact, + :svd_full, +] + for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] + @eval begin + function MatrixAlgebraKit.copy_input(::typeof($f), a::$T) + return _copy_input_squareeye($f, a.a) ⊗ _copy_input_squareeye($f, a.b) + end + end + end +end + +for f in [ + :default_eig_algorithm, + :default_eigh_algorithm, + :default_lq_algorithm, + :default_qr_algorithm, + :default_polar_algorithm, + :default_svd_algorithm, +] + f′ = Symbol("_", f, "_squareeye") + @eval begin + $f′(a; kwargs...) = $f(a; kwargs...) + $f′(a::Type{<:SquareEye}; kwargs...) = SquareEyeAlgorithm(; kwargs...) + end + for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] + @eval begin + function MatrixAlgebraKit.$f(A::Type{<:$T}; kwargs1=(;), kwargs2=(;), kwargs...) + A1, A2 = argument_types(A) + return KroneckerAlgorithm( + $f′(A1; kwargs..., kwargs1...), $f′(A2; kwargs..., kwargs2...) + ) + end + end + end +end + +# Defined to avoid type piracy. +_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a) +_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg) + +for f in [:left_null!, :right_null!] + @eval begin + _initialize_output_squareeye(::typeof($f), a::SquareEye) = a + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = a + end +end +for f in [ + :qr_compact!, + :qr_full!, + :left_orth!, + :left_polar!, + :lq_compact!, + :lq_full!, + :right_orth!, + :right_polar!, +] + @eval begin + _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a) + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a) + end +end +_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye) = complex.((a, a)) +_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye, alg) = complex.((a, a)) +_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye) = (real(a), a) +_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye, alg) = (real(a), a) +for f in [:svd_compact!, :svd_full!] + @eval begin + _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, real(a), a) + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, real(a), a) + end +end + +for f in [ + :eig_full!, + :eigh_full!, + :qr_compact!, + :qr_full!, + :left_orth!, + :left_polar!, + :lq_compact!, + :lq_full!, + :right_orth!, + :right_polar!, + :svd_compact!, + :svd_full!, +] + f′ = Symbol("_", f, "_squareeye") + @eval begin + $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) + $f′(a, F, alg::SquareEyeAlgorithm) = F + end + for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) + return _initialize_output_squareeye($f, a.a) .⊗ + _initialize_output_squareeye($f, a.b) + end + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::$T, alg::KroneckerAlgorithm + ) + return _initialize_output_squareeye($f, a.a, alg.a) .⊗ + _initialize_output_squareeye($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f( + a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... + ) + $f′(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...) + $f′(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...) + return F + end + end + end +end + +for f in [:left_null!, :right_null!] + f′ = Symbol("_", f, "_squareeye") + @eval begin + $f′(a, F; kwargs...) = $f(a, F; kwargs...) + $f′(a::SquareEye, F) = F + end + for T in [:SquareEyeKronecker, :KroneckerSquareEye] + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) + return _initialize_output_squareeye($f, a.a) ⊗ _initialize_output_squareeye($f, a.b) + end + function MatrixAlgebraKit.$f(a::$T, F; kwargs1=(;), kwargs2=(;), kwargs...) + $f′(a.a, F.a; kwargs..., kwargs1...) + $f′(a.b, F.b; kwargs..., kwargs2...) + return F + end + end + end +end + +function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::SquareEyeSquareEye) + return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) +end +function MatrixAlgebraKit.left_null!( + a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... +) + return throw(MethodError(left_null!, (a, F))) +end + +function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::SquareEyeSquareEye) + return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) +end +function MatrixAlgebraKit.right_null!( + a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... +) + return throw(MethodError(right_null!, (a, F))) +end + +_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye) = parent(a) +_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye, alg) = parent(a) +for f in [:eigh_vals!, svd_vals!] + @eval begin + _initialize_output_squareeye(::typeof($f), a::SquareEye) = real(parent(a)) + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = real(parent(a)) + end +end + +for f in [:eig_vals!, :eigh_vals!, :svd_vals!] + f′ = Symbol("_", f, "_squareeye") + @eval begin + $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) + $f′(a, F, alg::SquareEyeAlgorithm) = F + end + for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::$T, alg::KroneckerAlgorithm + ) + return _initialize_output_squareeye($f, a.a, alg.a) ⊗ + _initialize_output_squareeye($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f( + a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... + ) + $f′(a.a, F.a, alg.a; kwargs..., kwargs1...) + $f′(a.b, F.b, alg.b; kwargs..., kwargs2...) + return F + end + end + end +end + +using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate! + +struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy + strategy::T +end + +# Avoid instantiating the identity. +function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2}) + return a.a ⊗ a.b[I[1].b, I[2].b] +end +function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2}) + return a.a[I[1].a, I[2].a] ⊗ a.b +end +function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2}) + return a +end + +using FillArrays: OnesVector +const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} +const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} +const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} + +function MatrixAlgebraKit.findtruncated( + values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy +) + I = findtruncated(Vector(values), strategy.strategy) + prods = collect(only(axes(values)).product)[I] + I_data = unique(map(x -> x.a, prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> x.a == i, prods) == length(values.a) + end + return (:) × I_data +end +function MatrixAlgebraKit.findtruncated( + values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy +) + I = findtruncated(Vector(values), strategy.strategy) + prods = collect(only(axes(values)).product)[I] + I_data = unique(map(x -> x.b, prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> x.b == i, prods) == length(values.b) + end + return I_data × (:) +end +function MatrixAlgebraKit.findtruncated( + values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy +) + return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) +end + +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy + ) + return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) + end + function MatrixAlgebraKit.truncate!( + ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy + ) + I = findtruncated(diagview(D), strategy) + return (D[I, I], V[(:) × (:), I]) + end + end +end + +function MatrixAlgebraKit.truncate!( + f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy +) + return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) +end +function MatrixAlgebraKit.truncate!( + ::typeof(svd_trunc!), + (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, + strategy::KroneckerTruncationStrategy, +) + I = findtruncated(diagview(S), strategy) + return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) +end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl new file mode 100644 index 0000000..e2e770e --- /dev/null +++ b/src/kroneckerarray.jl @@ -0,0 +1,365 @@ +struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N} + a::A + b::B +end +function KroneckerArray(a::AbstractArray, b::AbstractArray) + if ndims(a) != ndims(b) + throw( + ArgumentError("Kronecker product requires arrays of the same number of dimensions.") + ) + end + elt = promote_type(eltype(a), eltype(b)) + return KroneckerArray(convert(AbstractArray{elt}, a), convert(AbstractArray{elt}, b)) +end +const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B} +const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B} + +using Adapt: Adapt, adapt +_adapt(to, a::AbstractArray) = adapt(to, a) +Adapt.adapt_structure(to, a::KroneckerArray) = _adapt(to, a.a) ⊗ _adapt(to, a.b) + +# Allows extra customization, like for `FillArrays.Eye`. +_copy(a::AbstractArray) = copy(a) + +function Base.copy(a::KroneckerArray) + return _copy(a.a) ⊗ _copy(a.b) +end +function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) + copyto!(dest.a, src.a) + copyto!(dest.b, src.b) + return dest +end + +# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`. +function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}}) + return similar(a, elt, axs) +end +function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple{Vararg{AbstractUnitRange}}) + return similar(arrayt, axs) +end + +function Base.similar( + a::AbstractArray, + elt::Type, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) + return _similar(a, elt, map(ax -> ax.product.a, axs)) ⊗ + _similar(a, elt, map(ax -> ax.product.b, axs)) +end +function Base.similar( + a::KroneckerArray, + elt::Type, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) + return _similar(a.a, elt, map(ax -> ax.product.a, axs)) ⊗ + _similar(a.b, elt, map(ax -> ax.product.b, axs)) +end +function Base.similar( + arrayt::Type{<:AbstractArray}, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) + return _similar(arrayt, map(ax -> ax.product.a, axs)) ⊗ + _similar(arrayt, map(ax -> ax.product.b, axs)) +end +function Base.similar( + arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) where {A,B} + return _similar(A, map(ax -> ax.product.a, axs)) ⊗ + _similar(B, map(ax -> ax.product.b, axs)) +end +function Base.similar( + ::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, sz::Tuple{Int,Vararg{Int}} +) where {A,B} + return similar(promote_type(A, B), sz) +end + +function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}}) + return (t[1]..., flatten(Base.tail(t))...) +end +function flatten(t::Tuple{Tuple}) + return t[1] +end +flatten(::Tuple{}) = () +function interleave(x::Tuple, y::Tuple) + length(x) == length(y) || throw(ArgumentError("Tuples must have the same length.")) + xy = ntuple(i -> (x[i], y[i]), length(x)) + return flatten(xy) +end +# TODO: Maybe use scalar indexing based on KroneckerProducts.jl logic for cartesian indexing: +# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 +function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N} + a′ = reshape(a, interleave(size(a), ntuple(one, N))) + b′ = reshape(b, interleave(ntuple(one, N), size(b))) + c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N))) + sz = ntuple(i -> size(a, i) * size(b, i), N) + return permutedims(reshape(c′, sz), reverse(ntuple(identity, N))) +end +kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b) +kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b) + +Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b) + +function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} + return convert(Array{T,N}, collect(a)) +end + +Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a)) + +function Base.axes(a::KroneckerArray) + return ntuple(ndims(a)) do dim + return CartesianProductUnitRange( + axes(a.a, dim) × axes(a.b, dim), Base.OneTo(size(a, dim)) + ) + end +end + +arguments(a::KroneckerArray) = (a.a, a.b) +arguments(a::KroneckerArray, n::Int) = arguments(a)[n] +argument_types(a::KroneckerArray) = argument_types(typeof(a)) +argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B) + +function Base.print_array(io::IO, a::KroneckerArray) + Base.print_array(io, a.a) + println(io, "\n ⊗") + Base.print_array(io, a.b) + return nothing +end +function Base.show(io::IO, a::KroneckerArray) + show(io, a.a) + print(io, " ⊗ ") + show(io, a.b) + return nothing +end + +⊗(a::AbstractArray, b::AbstractArray) = KroneckerArray(a, b) +⊗(a::Number, b::Number) = a * b +⊗(a::Number, b::AbstractArray) = a * b +⊗(a::AbstractArray, b::Number) = a * b + +function Base.getindex(a::KroneckerArray, i::Integer) + return a[CartesianIndices(a)[i]] +end + +# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing +# in the n-dimensional case and use it to replace the matrix and vector cases: +# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} + return error("Not implemented.") +end + +using GPUArraysCore: GPUArraysCore +function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) + GPUArraysCore.assertscalar("getindex") + # Code logic from Kronecker.jl: + # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105 + k, l = size(a.b) + return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1] +end + +function Base.getindex(a::KroneckerVector, i::Integer) + GPUArraysCore.assertscalar("getindex") + k = length(a.b) + return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] +end + +## function Base.getindex(a::KroneckerVector, i::CartesianProduct) +## return a.a[i.a] ⊗ a.b[i.b] +## end +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} + return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...] +end +# Fix ambigiuity error. +Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[] + +function Base.:(==)(a::KroneckerArray, b::KroneckerArray) + return a.a == b.a && a.b == b.b +end +function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...) + return isapprox(a.a, b.a; kwargs...) && isapprox(a.b, b.b; kwargs...) +end +function Base.iszero(a::KroneckerArray) + return iszero(a.a) || iszero(a.b) +end +function Base.isreal(a::KroneckerArray) + return isreal(a.a) && isreal(a.b) +end + +for f in [:transpose, :adjoint, :inv] + @eval begin + function Base.$f(a::KroneckerArray) + return $f(a.a) ⊗ $f(a.b) + end + end +end + +function Base.:*(a::Number, b::KroneckerArray) + return (a * b.a) ⊗ b.b +end +function Base.:*(a::KroneckerArray, b::Number) + return a.a ⊗ (a.b * b) +end +function Base.:/(a::KroneckerArray, b::Number) + return a.a ⊗ (a.b / b) +end +function Base.:-(a::KroneckerArray) + return (-a.a) ⊗ a.b +end + +for op in (:+, :-) + @eval begin + function Base.$op(a::KroneckerArray, b::KroneckerArray) + if a.b == b.b + return $op(a.a, b.a) ⊗ a.b + elseif a.a == b.a + return a.a ⊗ $op(a.b, b.b) + else + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or secord arguments match.", + ), + ) + end + end + end +end + +using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted +struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end +function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N} + return KroneckerStyle{N,a,b}() +end +function KroneckerStyle(a::AbstractArrayStyle{N}, b::AbstractArrayStyle{N}) where {N} + return KroneckerStyle{N}(a, b) +end +function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M} + return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}() +end +function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B} + return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B)) +end +function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N} + return KroneckerStyle{N}( + BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b) + ) +end +function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B} + ax_a = map(ax -> ax.product.a, axes(bc)) + ax_b = map(ax -> ax.product.b, axes(bc)) + bc_a = Broadcasted(A, nothing, (), ax_a) + bc_b = Broadcasted(B, nothing, (), ax_b) + a = similar(bc_a, elt) + b = similar(bc_b, elt) + return a ⊗ b +end +function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle}) + return throw( + ArgumentError( + "Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", + ), + ) +end + +function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) + return throw( + ArgumentError( + "Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", + ), + ) +end +function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...) + return throw( + ArgumentError( + "Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", + ), + ) +end + +function _map!!(f::F, dest::AbstractArray, srcs::AbstractArray...) where {F} + map!(f, dest, srcs...) + return dest +end + +for f in [:identity, :conj] + @eval begin + function Base.map!(::typeof($f), dest::KroneckerArray, src::KroneckerArray) + _map!!($f, dest.a, src.a) + _map!!($f, dest.b, src.b) + return dest + end + end +end + +for f in [:+, :-] + @eval begin + function Base.map!( + ::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray + ) + if a.b == b.b + map!($f, dest.a, a.a, b.a) + map!(identity, dest.b, a.b) + return dest + elseif a.a == b.a + map!(identity, dest.a, a.a) + map!($f, dest.b, a.b, b.b) + return dest + else + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or second arguments match.", + ), + ) + end + end + end +end + +function Base.map!( + f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, src::KroneckerArray +) + map!(f, dest.a, src.a) + map!(identity, dest.b, src.b) + return dest +end + +for op in [:*, :/] + @eval begin + function Base.map!( + f::Base.Fix2{typeof($op),<:Number}, dest::KroneckerArray, src::KroneckerArray + ) + map!(identity, dest.a, src.a) + map!(f, dest.b, src.b) + return dest + end + end +end + +using DiagonalArrays: DiagonalArrays, diagonal +function DiagonalArrays.diagonal(a::KroneckerArray) + return diagonal(a.a) ⊗ diagonal(a.b) +end + +function Base.real(a::KroneckerArray) + if iszero(imag(a.a)) || iszero(imag(a.b)) + return real(a.a) ⊗ real(a.b) + elseif iszero(real(a.a)) || iszero(real(a.b)) + return -imag(a.a) ⊗ imag(a.b) + end + return real(a.a) ⊗ real(a.b) - imag(a.a) ⊗ imag(a.b) +end +function Base.imag(a::KroneckerArray) + if iszero(imag(a.a)) || iszero(real(a.b)) + return real(a.a) ⊗ imag(a.b) + elseif iszero(real(a.a)) || iszero(imag(a.b)) + return imag(a.a) ⊗ real(a.b) + end + return real(a.a) ⊗ imag(a.b) + imag(a.a) ⊗ real(a.b) +end diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl new file mode 100644 index 0000000..2218363 --- /dev/null +++ b/src/linearalgebra.jl @@ -0,0 +1,179 @@ +using LinearAlgebra: + LinearAlgebra, + Diagonal, + Eigen, + SVD, + det, + diag, + eigen, + eigvals, + lq, + mul!, + norm, + qr, + svd, + svdvals, + tr + +using LinearAlgebra: LinearAlgebra, pinv +function LinearAlgebra.pinv(a::KroneckerArray; kwargs...) + return pinv(a.a; kwargs...) ⊗ pinv(a.b; kwargs...) +end + +function LinearAlgebra.diag(a::KroneckerArray) + return copy(diagview(a)) +end + +# Allows customizing multiplication for specific types +# such as `Eye * Eye`, which doesn't return `Eye`. +function _mul(a::AbstractArray, b::AbstractArray) + return a * b +end + +function Base.:*(a::KroneckerArray, b::KroneckerArray) + return _mul(a.a, b.a) ⊗ _mul(a.b, b.b) +end + +function LinearAlgebra.mul!( + c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number +) + iszero(β) || + iszero(c) || + throw( + ArgumentError( + "Can't multiple KroneckerArrays with nonzero β and nonzero destination." + ), + ) + mul!(c.a, a.a, b.a) + mul!(c.b, a.b, b.b, α, β) + return c +end + +function LinearAlgebra.tr(a::KroneckerArray) + return tr(a.a) ⊗ tr(a.b) +end + +function LinearAlgebra.norm(a::KroneckerArray, p::Int=2) + return norm(a.a, p) ⊗ norm(a.b, p) +end + +# Matrix functions +const MATRIX_FUNCTIONS = [ + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + +for f in MATRIX_FUNCTIONS + @eval begin + function Base.$f(a::KroneckerArray) + return throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported.")) + end + end +end + +using LinearAlgebra: checksquare +function LinearAlgebra.det(a::KroneckerArray) + checksquare(a.a) + checksquare(a.b) + return det(a.a) ^ size(a.b, 1) * det(a.b) ^ size(a.a, 1) +end + +function LinearAlgebra.svd(a::KroneckerArray) + Fa = svd(a.a) + Fb = svd(a.b) + return SVD(Fa.U ⊗ Fb.U, Fa.S ⊗ Fb.S, Fa.Vt ⊗ Fb.Vt) +end +function LinearAlgebra.svdvals(a::KroneckerArray) + return svdvals(a.a) ⊗ svdvals(a.b) +end +function LinearAlgebra.eigen(a::KroneckerArray) + Fa = eigen(a.a) + Fb = eigen(a.b) + return Eigen(Fa.values ⊗ Fb.values, Fa.vectors ⊗ Fb.vectors) +end +function LinearAlgebra.eigvals(a::KroneckerArray) + return eigvals(a.a) ⊗ eigvals(a.b) +end + +struct KroneckerQ{A,B} + a::A + b::B +end +function Base.:*(a::KroneckerQ, b::KroneckerQ) + return (a.a * b.a) ⊗ (a.b * b.b) +end +function Base.:*(a::KroneckerQ, b::KroneckerArray) + return (a.a * b.a) ⊗ (a.b * b.b) +end +function Base.:*(a::KroneckerArray, b::KroneckerQ) + return (a.a * b.a) ⊗ (a.b * b.b) +end +function Base.adjoint(a::KroneckerQ) + return KroneckerQ(a.a', a.b') +end + +struct KroneckerQR{QQ,RR} + Q::QQ + R::RR +end +Base.iterate(F::KroneckerQR) = (F.Q, Val(:R)) +Base.iterate(F::KroneckerQR, ::Val{:R}) = (F.R, Val(:done)) +Base.iterate(F::KroneckerQR, ::Val{:done}) = nothing +function ⊗(a::LinearAlgebra.QRCompactWYQ, b::LinearAlgebra.QRCompactWYQ) + return KroneckerQ(a, b) +end +function LinearAlgebra.qr(a::KroneckerArray) + Fa = qr(a.a) + Fb = qr(a.b) + return KroneckerQR(Fa.Q ⊗ Fb.Q, Fa.R ⊗ Fb.R) +end + +struct KroneckerLQ{LL,QQ} + L::LL + Q::QQ +end +Base.iterate(F::KroneckerLQ) = (F.L, Val(:Q)) +Base.iterate(F::KroneckerLQ, ::Val{:Q}) = (F.Q, Val(:done)) +Base.iterate(F::KroneckerLQ, ::Val{:done}) = nothing +function ⊗(a::LinearAlgebra.LQPackedQ, b::LinearAlgebra.LQPackedQ) + return KroneckerQ(a, b) +end +function LinearAlgebra.lq(a::KroneckerArray) + Fa = lq(a.a) + Fb = lq(a.b) + return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q) +end + +using DerivableInterfaces: DerivableInterfaces, zero! +function DerivableInterfaces.zero!(a::KroneckerArray) + zero!(a.a) + zero!(a.b) + return a +end diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl new file mode 100644 index 0000000..5cf8b77 --- /dev/null +++ b/src/matrixalgebrakit.jl @@ -0,0 +1,177 @@ +using MatrixAlgebraKit: + MatrixAlgebraKit, + AbstractAlgorithm, + TruncationStrategy, + default_eig_algorithm, + default_eigh_algorithm, + default_lq_algorithm, + default_polar_algorithm, + default_qr_algorithm, + default_svd_algorithm, + eig_full!, + eig_trunc!, + eig_vals!, + eigh_full!, + eigh_trunc!, + eigh_vals!, + initialize_output, + left_null!, + left_orth!, + left_polar!, + lq_compact!, + lq_full!, + qr_compact!, + qr_full!, + right_null!, + right_orth!, + right_polar!, + svd_compact!, + svd_full!, + svd_trunc!, + svd_vals!, + truncate! + +using MatrixAlgebraKit: MatrixAlgebraKit, diagview +function MatrixAlgebraKit.diagview(a::KroneckerMatrix) + return diagview(a.a) ⊗ diagview(a.b) +end + +struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm + a::A + b::B +end + +using MatrixAlgebraKit: + copy_input, + eig_full, + eig_vals, + eigh_full, + eigh_vals, + qr_compact, + qr_full, + left_null, + left_orth, + left_polar, + lq_compact, + lq_full, + right_null, + right_orth, + right_polar, + svd_compact, + svd_full + +for f in [ + :eig_full, + :eigh_full, + :qr_compact, + :qr_full, + :left_polar, + :lq_compact, + :lq_full, + :right_polar, + :svd_compact, + :svd_full, +] + @eval begin + function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix) + return copy_input($f, a.a) ⊗ copy_input($f, a.b) + end + end +end + +for f in [ + :default_eig_algorithm, + :default_eigh_algorithm, + :default_lq_algorithm, + :default_qr_algorithm, + :default_polar_algorithm, + :default_svd_algorithm, +] + @eval begin + function MatrixAlgebraKit.$f( + A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs... + ) + A1, A2 = argument_types(A) + return KroneckerAlgorithm( + $f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...) + ) + end + end +end + +# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +function MatrixAlgebraKit.default_algorithm( + ::typeof(qr_compact!), A::Type{<:KroneckerMatrix}; kwargs... +) + return default_qr_algorithm(A; kwargs...) +end +# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +function MatrixAlgebraKit.default_algorithm( + ::typeof(qr_full!), A::Type{<:KroneckerMatrix}; kwargs... +) + return default_qr_algorithm(A; kwargs...) +end + +for f in [ + :eig_full!, + :eigh_full!, + :qr_compact!, + :qr_full!, + :left_polar!, + :lq_compact!, + :lq_full!, + :right_polar!, + :svd_compact!, + :svd_full!, +] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm + ) + return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f( + a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... + ) + $f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...) + $f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...) + return F + end + end +end + +for f in [:eig_vals!, :eigh_vals!, :svd_vals!] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm + ) + return initialize_output($f, a.a, alg.a) ⊗ initialize_output($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) + $f(a.a, F.a, alg.a) + $f(a.b, F.b, alg.b) + return F + end + end +end + +for f in [:left_orth!, :right_orth!] + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) + return initialize_output($f, a.a) .⊗ initialize_output($f, a.b) + end + end +end + +for f in [:left_null!, :right_null!] + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) + return initialize_output($f, a.a) ⊗ initialize_output($f, a.b) + end + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...) + $f(a.a, F.a; kwargs..., kwargs1...) + $f(a.b, F.b; kwargs..., kwargs2...) + return F + end + end +end diff --git a/test/Project.toml b/test/Project.toml index e09bdbb..8d55cb0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,7 +19,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Adapt = "4" Aqua = "0.8" BlockArrays = "1.6" -BlockSparseArrays = "0.7" +BlockSparseArrays = "0.7.12" DerivableInterfaces = "0.5" FillArrays = "1" JLArrays = "0.2" diff --git a/test/test_basics.jl b/test/test_basics.jl index aa674f2..45f6403 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,6 @@ using Adapt: adapt using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted using DerivableInterfaces: zero! -using FillArrays: Eye using JLArrays: JLArray using KroneckerArrays: KroneckerArrays, @@ -18,6 +17,7 @@ using KroneckerArrays: using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr using StableRNGs: StableRNG using Test: @test, @test_broken, @test_throws, @testset +using TestExtras: @constinferred elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "KroneckerArrays (eltype=$elt)" for elt in elts @@ -25,18 +25,18 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test length(p) == 6 @test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5] - r = cartesianrange(2, 3) + r = @constinferred cartesianrange(2, 3) @test r === - cartesianrange(2 × 3) === - cartesianrange(Base.OneTo(2), Base.OneTo(3)) === - cartesianrange(Base.OneTo(2) × Base.OneTo(3)) - @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) + @constinferred(cartesianrange(2 × 3)) === + @constinferred(cartesianrange(Base.OneTo(2), Base.OneTo(3))) === + @constinferred(cartesianrange(Base.OneTo(2) × Base.OneTo(3))) + @test @constinferred(cartesianproduct(r)) === Base.OneTo(2) × Base.OneTo(3) @test unproduct(r) === Base.OneTo(6) @test length(r) == 6 @test first(r) == 1 @test last(r) == 6 - r = cartesianrange(2 × 3, 2:7) + r = @constinferred(cartesianrange(2 × 3, 2:7)) @test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7) @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) @test unproduct(r) === 2:7 @@ -44,7 +44,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test first(r) == 2 @test last(r) == 7 - a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3)) b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c = a.a ⊗ b.b @test a isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)} @@ -182,189 +182,3 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) end end end - -@testset "FillArrays.Eye" begin - MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS - if VERSION < v"1.11-" - # `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11. - MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt]) - end - - a = Eye(2) ⊗ randn(3, 3) - @test size(a) == (6, 6) - @test a + a == Eye(2) ⊗ (2a.b) - @test 2a == Eye(2) ⊗ (2a.b) - @test a * a == Eye(2) ⊗ (a.b * a.b) - - a = randn(3, 3) ⊗ Eye(2) - @test size(a) == (6, 6) - @test a + a == (2a.a) ⊗ Eye(2) - @test 2a == (2a.a) ⊗ Eye(2) - @test a * a == (a.a * a.a) ⊗ Eye(2) - - # similar - a = Eye(2) ⊗ randn(3, 3) - for a′ in ( - similar(a), - similar(a, eltype(a)), - similar(a, axes(a)), - similar(a, eltype(a), axes(a)), - similar(typeof(a), axes(a)), - ) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} - @test a′.a === a.a - end - - a = Eye(2) ⊗ randn(3, 3) - for args in ((Float32,), (Float32, axes(a))) - a′ = similar(a, args...) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{Float32,ndims(a)} - @test a′.a === Eye{Float32}(2) - end - - a = randn(3, 3) ⊗ Eye(2) - for a′ in ( - similar(a), - similar(a, eltype(a)), - similar(a, axes(a)), - similar(a, eltype(a), axes(a)), - similar(typeof(a), axes(a)), - ) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} - @test a′.b === a.b - end - - a = randn(3, 3) ⊗ Eye(2) - for args in ((Float32,), (Float32, axes(a))) - a′ = similar(a, args...) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{Float32,ndims(a)} - @test a′.b === Eye{Float32}(2) - end - - a = Eye(3) ⊗ Eye(2) - for a′ in ( - similar(a), - similar(a, eltype(a)), - similar(a, axes(a)), - similar(a, eltype(a), axes(a)), - similar(typeof(a), axes(a)), - ) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} - @test a′.a === a.a - @test a′.b === a.b - end - - a = Eye(3) ⊗ Eye(2) - for args in ((Float32,), (Float32, axes(a))) - a′ = similar(a, args...) - @test size(a′) == (6, 6) - @test a′ isa KroneckerArray{Float32,ndims(a)} - @test a′.a === Eye{Float32}(3) - @test a′.b === Eye{Float32}(2) - end - - # DerivableInterfaces.zero! - for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) - zero!(a) - @test iszero(a) - end - a = Eye(3) ⊗ Eye(2) - @test_throws ArgumentError zero!(a) - - # map!(+, ...) - for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) - a′ = similar(a) - map!(+, a′, a, a) - @test collect(a′) ≈ 2 * collect(a) - end - a = Eye(3) ⊗ Eye(2) - a′ = similar(a) - @test_throws ErrorException map!(+, a′, a, a) - - # map!(-, ...) - for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) - a′ = similar(a) - map!(-, a′, a, a) - @test norm(collect(a′)) ≈ 0 - end - a = Eye(3) ⊗ Eye(2) - a′ = similar(a) - @test_throws ErrorException map!(-, a′, a, a) - - # map!(-, b, a) - for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) - a′ = similar(a) - map!(-, a′, a) - @test collect(a′) ≈ -collect(a) - end - a = Eye(3) ⊗ Eye(2) - a′ = similar(a) - @test_throws ErrorException map!(-, a′, a) - - # Eye ⊗ A - rng = StableRNG(123) - a = Eye(2) ⊗ randn(rng, 3, 3) - for f in MATRIX_FUNCTIONS - @eval begin - fa = $f($a) - @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) - @test fa.a isa Eye - end - end - - fa = inv(a) - @test collect(fa) ≈ inv(collect(a)) - @test fa.a isa Eye - - fa = pinv(a) - @test collect(fa) ≈ pinv(collect(a)) - @test fa.a isa Eye - - @test det(a) ≈ det(collect(a)) - - # A ⊗ Eye - rng = StableRNG(123) - a = randn(rng, 3, 3) ⊗ Eye(2) - for f in setdiff(MATRIX_FUNCTIONS, [:atanh]) - @eval begin - fa = $f($a) - @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) - @test fa.b isa Eye - end - end - - fa = inv(a) - @test collect(fa) ≈ inv(collect(a)) - @test fa.b isa Eye - - fa = pinv(a) - @test collect(fa) ≈ pinv(collect(a)) - @test fa.b isa Eye - - @test det(a) ≈ det(collect(a)) - - # Eye ⊗ Eye - a = Eye(2) ⊗ Eye(2) - for f in KroneckerArrays.MATRIX_FUNCTIONS - @eval begin - @test_throws ArgumentError $f($a) - end - end - - fa = inv(a) - @test fa == a - @test fa.a isa Eye - @test fa.b isa Eye - - fa = pinv(a) - @test fa == a - @test fa.a isa Eye - @test fa.b isa Eye - - @test det(a) ≈ det(collect(a)) ≈ 1 -end diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index a6f5ff2..5ff50cb 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -23,7 +23,7 @@ arrayts = (Array, JLArray) Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), ) a = dev(blocksparse(d, r, r)) - @test_broken sprint(show, a) + @test sprint(show, a) isa String @test sprint(show, MIME("text/plain"), a) isa String @test blocktype(a) === valtype(d) @test a isa BlockSparseMatrix{elt,valtype(d)} @@ -70,8 +70,8 @@ arrayts = (Array, JLArray) @test_broken a[Block.(1:2), Block(2)] end -@testset "BlockSparseArraysExt, SquareEyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in - arrayts, +@testset "BlockSparseArraysExt, EyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in + arrayts, elt in elts if arrayt == JLArray @@ -81,55 +81,56 @@ end end dev = adapt(arrayt) - r = blockrange([2 × 2, 3 × 3]) + r = @constinferred blockrange([2 × 2, 3 × 3]) d = Dict( Block(1, 1) => Eye{elt}(2, 2) ⊗ randn(elt, 2, 2), Block(2, 2) => Eye{elt}(3, 3) ⊗ randn(elt, 3, 3), ) - a = dev(blocksparse(d, r, r)) - @test_broken sprint(show, a) + a = @constinferred dev(blocksparse(d, r, r)) + @test sprint(show, a) == sprint(show, Array(a)) @test sprint(show, MIME("text/plain"), a) isa String - @test_broken blocktype(a) === valtype(d) - @test_broken a isa BlockSparseMatrix{elt,valtype(d)} - @test a[Block(1, 1)] == dev(d[Block(1, 1)]) - @test_broken a[Block(1, 1)] isa valtype(d) - @test a[Block(2, 2)] == dev(d[Block(2, 2)]) - @test_broken a[Block(2, 2)] isa valtype(d) - @test iszero(a[Block(2, 1)]) - @test a[Block(2, 1)] == dev(zeros(elt, 3, 2) ⊗ zeros(elt, 3, 2)) - @test_broken a[Block(2, 1)] isa valtype(d) + @test @constinferred(blocktype(a)) === valtype(d) + @test a isa BlockSparseMatrix{elt,valtype(d)} + @test @constinferred(a[Block(1, 1)]) == dev(d[Block(1, 1)]) + @test @constinferred(a[Block(1, 1)]) isa valtype(d) + @test @constinferred(a[Block(2, 2)]) == dev(d[Block(2, 2)]) + @test @constinferred(a[Block(2, 2)]) isa valtype(d) + @test @constinferred(iszero(a[Block(2, 1)])) + @test a[Block(2, 1)] == dev(Eye(3, 2) ⊗ zeros(elt, 3, 2)) + @test a[Block(2, 1)] isa valtype(d) @test iszero(a[Block(1, 2)]) - @test a[Block(1, 2)] == dev(zeros(elt, 2, 3) ⊗ zeros(elt, 2, 3)) - @test_broken a[Block(1, 2)] isa valtype(d) + @test a[Block(1, 2)] == dev(Eye(2, 3) ⊗ zeros(elt, 2, 3)) + @test a[Block(1, 2)] isa valtype(d) - b = a * a + b = @constinferred a * a @test typeof(b) === typeof(a) @test Array(b) ≈ Array(a) * Array(a) + # Type inference is broken for this operation. + # b = @constinferred a + a b = a + a @test typeof(b) === typeof(a) @test Array(b) ≈ Array(a) + Array(a) + # Type inference is broken for this operation. + # b = @constinferred 3a b = 3a @test typeof(b) === typeof(a) @test Array(b) ≈ 3Array(a) + # Type inference is broken for this operation. + # b = @constinferred a / 3 b = a / 3 @test typeof(b) === typeof(a) @test Array(b) ≈ Array(a) / 3 - @test norm(a) ≈ norm(Array(a)) + @test @constinferred(norm(a)) ≈ norm(Array(a)) - if arrayt == Array - @test Array(inv(a)) ≈ inv(Array(a)) - else - # Broken for JLArray, it seems like `inv` isn't - # type stable. - @test_broken inv(a) - end + b = @constinferred exp(a) + @test Array(b) ≈ exp(Array(a)) # Broken operations - # @test_broken exp(a) + @test_broken inv(a) @test_broken svd_compact(a) @test_broken a[Block.(1:2), Block(2)] end diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl new file mode 100644 index 0000000..001cda9 --- /dev/null +++ b/test/test_fillarrays.jl @@ -0,0 +1,192 @@ +using DerivableInterfaces: zero! +using FillArrays: Eye +using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗ +using LinearAlgebra: det, norm, pinv +using StableRNGs: StableRNG +using Test: @test, @testset + +@testset "FillArrays.Eye" begin + MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS + if VERSION < v"1.11-" + # `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11. + MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt]) + end + + a = Eye(2) ⊗ randn(3, 3) + @test size(a) == (6, 6) + @test a + a == Eye(2) ⊗ (2a.b) + @test 2a == Eye(2) ⊗ (2a.b) + @test a * a == Eye(2) ⊗ (a.b * a.b) + + a = randn(3, 3) ⊗ Eye(2) + @test size(a) == (6, 6) + @test a + a == (2a.a) ⊗ Eye(2) + @test 2a == (2a.a) ⊗ Eye(2) + @test a * a == (a.a * a.a) ⊗ Eye(2) + + # similar + a = Eye(2) ⊗ randn(3, 3) + for a′ in ( + similar(a), + similar(a, eltype(a)), + similar(a, axes(a)), + similar(a, eltype(a), axes(a)), + similar(typeof(a), axes(a)), + ) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} + @test a′.a === a.a + end + + a = Eye(2) ⊗ randn(3, 3) + for args in ((Float32,), (Float32, axes(a))) + a′ = similar(a, args...) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32,ndims(a)} + @test a′.a === Eye{Float32}(2) + end + + a = randn(3, 3) ⊗ Eye(2) + for a′ in ( + similar(a), + similar(a, eltype(a)), + similar(a, axes(a)), + similar(a, eltype(a), axes(a)), + similar(typeof(a), axes(a)), + ) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} + @test a′.b === a.b + end + + a = randn(3, 3) ⊗ Eye(2) + for args in ((Float32,), (Float32, axes(a))) + a′ = similar(a, args...) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32,ndims(a)} + @test a′.b === Eye{Float32}(2) + end + + a = Eye(3) ⊗ Eye(2) + for a′ in ( + similar(a), + similar(a, eltype(a)), + similar(a, axes(a)), + similar(a, eltype(a), axes(a)), + similar(typeof(a), axes(a)), + ) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} + @test a′.a === a.a + @test a′.b === a.b + end + + a = Eye(3) ⊗ Eye(2) + for args in ((Float32,), (Float32, axes(a))) + a′ = similar(a, args...) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32,ndims(a)} + @test a′.a === Eye{Float32}(3) + @test a′.b === Eye{Float32}(2) + end + + # DerivableInterfaces.zero! + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + zero!(a) + @test iszero(a) + end + a = Eye(3) ⊗ Eye(2) + @test_throws ArgumentError zero!(a) + + # map!(+, ...) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(+, a′, a, a) + @test collect(a′) ≈ 2 * collect(a) + end + a = Eye(3) ⊗ Eye(2) + a′ = similar(a) + @test_throws ErrorException map!(+, a′, a, a) + + # map!(-, ...) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(-, a′, a, a) + @test norm(collect(a′)) ≈ 0 + end + a = Eye(3) ⊗ Eye(2) + a′ = similar(a) + @test_throws ErrorException map!(-, a′, a, a) + + # map!(-, b, a) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(-, a′, a) + @test collect(a′) ≈ -collect(a) + end + a = Eye(3) ⊗ Eye(2) + a′ = similar(a) + @test_throws ErrorException map!(-, a′, a) + + # Eye ⊗ A + rng = StableRNG(123) + a = Eye(2) ⊗ randn(rng, 3, 3) + for f in MATRIX_FUNCTIONS + @eval begin + fa = $f($a) + @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) + @test fa.a isa Eye + end + end + + fa = inv(a) + @test collect(fa) ≈ inv(collect(a)) + @test fa.a isa Eye + + fa = pinv(a) + @test collect(fa) ≈ pinv(collect(a)) + @test fa.a isa Eye + + @test det(a) ≈ det(collect(a)) + + # A ⊗ Eye + rng = StableRNG(123) + a = randn(rng, 3, 3) ⊗ Eye(2) + for f in setdiff(MATRIX_FUNCTIONS, [:atanh]) + @eval begin + fa = $f($a) + @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) + @test fa.b isa Eye + end + end + + fa = inv(a) + @test collect(fa) ≈ inv(collect(a)) + @test fa.b isa Eye + + fa = pinv(a) + @test collect(fa) ≈ pinv(collect(a)) + @test fa.b isa Eye + + @test det(a) ≈ det(collect(a)) + + # Eye ⊗ Eye + a = Eye(2) ⊗ Eye(2) + for f in KroneckerArrays.MATRIX_FUNCTIONS + @eval begin + @test_throws ArgumentError $f($a) + end + end + + fa = inv(a) + @test fa == a + @test fa.a isa Eye + @test fa.b isa Eye + + fa = pinv(a) + @test fa == a + @test fa.a isa Eye + @test fa.b isa Eye + + @test det(a) ≈ det(collect(a)) ≈ 1 +end diff --git a/test/test_fillarrays_matrixalgebrakit.jl b/test/test_fillarrays_matrixalgebrakit.jl new file mode 100644 index 0000000..bbc08d8 --- /dev/null +++ b/test/test_fillarrays_matrixalgebrakit.jl @@ -0,0 +1,275 @@ +using FillArrays: Eye, Ones +using KroneckerArrays: ⊗, arguments +using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm +using MatrixAlgebraKit: + eig_full, + eig_trunc, + eig_vals, + eigh_full, + eigh_trunc, + eigh_vals, + left_null, + left_orth, + left_polar, + lq_compact, + lq_full, + qr_compact, + qr_full, + right_null, + right_orth, + right_polar, + svd_compact, + svd_full, + svd_trunc, + svd_vals +using Test: @test, @test_throws, @testset +using TestExtras: @constinferred + +herm(a) = parent(hermitianpart(a)) + +@testset "MatrixAlgebraKit + Eye" begin + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + d, v = @constinferred eig_full(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye{complex(elt)} + @test arguments(v, 1) isa Eye{complex(elt)} + + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) + d, v = @constinferred eig_full(a) + @test a * v ≈ v * d + @test arguments(d, 2) isa Eye{complex(elt)} + @test arguments(v, 2) isa Eye{complex(elt)} + + a = Eye{elt}(3) ⊗ Eye{elt}(3) + d, v = @constinferred eig_full(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye{complex(elt)} + @test arguments(d, 2) isa Eye{complex(elt)} + @test arguments(v, 1) isa Eye{complex(elt)} + @test arguments(v, 2) isa Eye{complex(elt)} + end + + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) + d, v = @constinferred eigh_full(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) + d, v = @constinferred eigh_full(a) + @test a * v ≈ v * d + @test arguments(d, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + + a = Eye{elt}(3) ⊗ Eye{elt}(3) + d, v = @constinferred eigh_full(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye{real(elt)} + @test arguments(d, 2) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + @test arguments(v, 2) isa Eye{elt} + end + + for f in (eig_trunc, eigh_trunc) + a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) + d, v = f(a; trunc=(; maxrank=7)) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye + @test arguments(v, 1) isa Eye + @test size(d) == (6, 6) + @test size(v) == (9, 6) + + a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) + d, v = f(a; trunc=(; maxrank=7)) + @test a * v ≈ v * d + @test arguments(d, 2) isa Eye + @test arguments(v, 2) isa Eye + @test size(d) == (6, 6) + @test size(v) == (9, 6) + + a = Eye(3) ⊗ Eye(3) + @test_throws ArgumentError f(a) + end + + for f in (eig_vals, eigh_vals) + a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) + d = @constinferred f(a) + d′ = f(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 1) isa Ones + @test arguments(d, 2) ≈ f(arguments(a, 2)) + + a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) + d = @constinferred f(a) + d′ = f(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 2) isa Ones + @test arguments(d, 1) ≈ f(arguments(a, 1)) + + a = Eye(3) ⊗ Eye(3) + d = @constinferred f(a) + @test d == Ones(3) ⊗ Ones(3) + @test arguments(d, 1) isa Ones + @test arguments(d, 2) isa Ones + end + + for f in ( + left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar + ) + a = Eye(3) ⊗ randn(3, 3) + x, y = @constinferred f(a) + @test x * y ≈ a + @test arguments(x, 1) isa Eye + @test arguments(y, 1) isa Eye + + a = randn(3, 3) ⊗ Eye(3) + x, y = @constinferred f(a) + @test x * y ≈ a + @test arguments(x, 2) isa Eye + @test arguments(y, 2) isa Eye + + a = Eye(3) ⊗ Eye(3) + x, y = f(a) + @test x * y ≈ a + @test arguments(x, 1) isa Eye + @test arguments(y, 1) isa Eye + @test arguments(x, 2) isa Eye + @test arguments(y, 2) isa Eye + end + + for f in (svd_compact, svd_full) + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + u, s, v = @constinferred f(a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 1) isa Eye{elt} + @test arguments(s, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + + a = randn(elt, 3, 3) ⊗ Eye{elt}(3) + u, s, v = @constinferred f(a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 2) isa Eye{elt} + @test arguments(s, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + + a = Eye{elt}(3) ⊗ Eye{elt}(3) + u, s, v = @constinferred f(a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 1) isa Eye{elt} + @test arguments(s, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + @test arguments(u, 2) isa Eye{elt} + @test arguments(s, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + end + end + + # svd_trunc + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 1) isa Eye{elt} + @test arguments(s, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end + + for elt in (Float32, ComplexF32) + a = randn(elt, 3, 3) ⊗ Eye{elt}(3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 2) isa Eye{elt} + @test arguments(s, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end + + a = Eye(3) ⊗ Eye(3) + @test_throws ArgumentError svd_trunc(a) + + # svd_vals + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + d = @constinferred svd_vals(a) + d′ = svd_vals(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 1) isa Ones{real(elt)} + @test arguments(d, 2) ≈ svd_vals(arguments(a, 2)) + end + + for elt in (Float32, ComplexF32) + a = randn(elt, 3, 3) ⊗ Eye{elt}(3) + d = @constinferred svd_vals(a) + d′ = svd_vals(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 2) isa Ones{real(elt)} + @test arguments(d, 1) ≈ svd_vals(arguments(a, 1)) + end + + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ Eye{elt}(3) + d = @constinferred svd_vals(a) + @test d == Ones(3) ⊗ Ones(3) + @test arguments(d, 1) isa Ones{real(elt)} + @test arguments(d, 2) isa Ones{real(elt)} + end + + # left_null + a = Eye(3) ⊗ randn(3, 3) + n = @constinferred left_null(a) + @test norm(n' * a) ≈ 0 + @test arguments(n, 1) isa Eye + + a = randn(3, 3) ⊗ Eye(3) + n = @constinferred left_null(a) + @test norm(n' * a) ≈ 0 + @test arguments(n, 2) isa Eye + + a = Eye(3) ⊗ Eye(3) + @test_throws MethodError left_null(a) + + # right_null + a = Eye(3) ⊗ randn(3, 3) + n = @constinferred right_null(a) + @test norm(a * n') ≈ 0 + @test arguments(n, 1) isa Eye + + a = randn(3, 3) ⊗ Eye(3) + n = @constinferred right_null(a) + @test norm(a * n') ≈ 0 + @test arguments(n, 2) isa Eye + + a = Eye(3) ⊗ Eye(3) + @test_throws MethodError right_null(a) +end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index e1b96e5..8bf4e3e 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,4 +1,3 @@ -using FillArrays: Eye, Ones using KroneckerArrays: ⊗, arguments using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm using MatrixAlgebraKit: @@ -128,250 +127,3 @@ herm(a) = parent(hermitianpart(a)) s = svd_vals(a) @test s ≈ diag(svd_compact(a)[2]) end - -@testset "MatrixAlgebraKit + Eye" begin - for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) - d, v = @constinferred eig_full(a) - @test a * v ≈ v * d - @test arguments(d, 1) isa Eye{complex(elt)} - @test arguments(v, 1) isa Eye{complex(elt)} - - a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) - d, v = @constinferred eig_full(a) - @test a * v ≈ v * d - @test arguments(d, 2) isa Eye{complex(elt)} - @test arguments(v, 2) isa Eye{complex(elt)} - - a = Eye{elt}(3) ⊗ Eye{elt}(3) - d, v = @constinferred eig_full(a) - @test a * v ≈ v * d - @test arguments(d, 1) isa Eye{complex(elt)} - @test arguments(d, 2) isa Eye{complex(elt)} - @test arguments(v, 1) isa Eye{complex(elt)} - @test arguments(v, 2) isa Eye{complex(elt)} - end - - for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) - d, v = @constinferred eigh_full(a) - @test a * v ≈ v * d - @test arguments(d, 1) isa Eye{real(elt)} - @test arguments(v, 1) isa Eye{elt} - - a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) - d, v = @constinferred eigh_full(a) - @test a * v ≈ v * d - @test arguments(d, 2) isa Eye{real(elt)} - @test arguments(v, 2) isa Eye{elt} - - a = Eye{elt}(3) ⊗ Eye{elt}(3) - d, v = @constinferred eigh_full(a) - @test a * v ≈ v * d - @test arguments(d, 1) isa Eye{real(elt)} - @test arguments(d, 2) isa Eye{real(elt)} - @test arguments(v, 1) isa Eye{elt} - @test arguments(v, 2) isa Eye{elt} - end - - for f in (eig_trunc, eigh_trunc) - a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) - d, v = f(a; trunc=(; maxrank=7)) - @test a * v ≈ v * d - @test arguments(d, 1) isa Eye - @test arguments(v, 1) isa Eye - @test size(d) == (6, 6) - @test size(v) == (9, 6) - - a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) - d, v = f(a; trunc=(; maxrank=7)) - @test a * v ≈ v * d - @test arguments(d, 2) isa Eye - @test arguments(v, 2) isa Eye - @test size(d) == (6, 6) - @test size(v) == (9, 6) - - a = Eye(3) ⊗ Eye(3) - @test_throws ArgumentError f(a) - end - - for f in (eig_vals, eigh_vals) - a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) - d = @constinferred f(a) - d′ = f(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 1) isa Ones - @test arguments(d, 2) ≈ f(arguments(a, 2)) - - a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) - d = @constinferred f(a) - d′ = f(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 2) isa Ones - @test arguments(d, 1) ≈ f(arguments(a, 1)) - - a = Eye(3) ⊗ Eye(3) - d = @constinferred f(a) - @test d == Ones(3) ⊗ Ones(3) - @test arguments(d, 1) isa Ones - @test arguments(d, 2) isa Ones - end - - for f in ( - left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar - ) - a = Eye(3) ⊗ randn(3, 3) - x, y = @constinferred f(a) - @test x * y ≈ a - @test arguments(x, 1) isa Eye - @test arguments(y, 1) isa Eye - - a = randn(3, 3) ⊗ Eye(3) - x, y = @constinferred f(a) - @test x * y ≈ a - @test arguments(x, 2) isa Eye - @test arguments(y, 2) isa Eye - - a = Eye(3) ⊗ Eye(3) - x, y = f(a) - @test x * y ≈ a - @test arguments(x, 1) isa Eye - @test arguments(y, 1) isa Eye - @test arguments(x, 2) isa Eye - @test arguments(y, 2) isa Eye - end - - for f in (svd_compact, svd_full) - for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) - u, s, v = @constinferred f(a) - @test u * s * v ≈ a - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - @test arguments(u, 1) isa Eye{elt} - @test arguments(s, 1) isa Eye{real(elt)} - @test arguments(v, 1) isa Eye{elt} - - a = randn(elt, 3, 3) ⊗ Eye{elt}(3) - u, s, v = @constinferred f(a) - @test u * s * v ≈ a - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - @test arguments(u, 2) isa Eye{elt} - @test arguments(s, 2) isa Eye{real(elt)} - @test arguments(v, 2) isa Eye{elt} - - a = Eye{elt}(3) ⊗ Eye{elt}(3) - u, s, v = @constinferred f(a) - @test u * s * v ≈ a - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - @test arguments(u, 1) isa Eye{elt} - @test arguments(s, 1) isa Eye{real(elt)} - @test arguments(v, 1) isa Eye{elt} - @test arguments(u, 2) isa Eye{elt} - @test arguments(s, 2) isa Eye{real(elt)} - @test arguments(v, 2) isa Eye{elt} - end - end - - # svd_trunc - for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) - # TODO: Type inference is broken for `svd_trunc`, - # look into fixing it. - # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 1) isa Eye{elt} - @test arguments(s, 1) isa Eye{real(elt)} - @test arguments(v, 1) isa Eye{elt} - @test size(u) == (9, 6) - @test size(s) == (6, 6) - @test size(v) == (6, 9) - end - - for elt in (Float32, ComplexF32) - a = randn(elt, 3, 3) ⊗ Eye{elt}(3) - # TODO: Type inference is broken for `svd_trunc`, - # look into fixing it. - # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 2) isa Eye{elt} - @test arguments(s, 2) isa Eye{real(elt)} - @test arguments(v, 2) isa Eye{elt} - @test size(u) == (9, 6) - @test size(s) == (6, 6) - @test size(v) == (6, 9) - end - - a = Eye(3) ⊗ Eye(3) - @test_throws ArgumentError svd_trunc(a) - - # svd_vals - for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) - d = @constinferred svd_vals(a) - d′ = svd_vals(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 1) isa Ones{real(elt)} - @test arguments(d, 2) ≈ svd_vals(arguments(a, 2)) - end - - for elt in (Float32, ComplexF32) - a = randn(elt, 3, 3) ⊗ Eye{elt}(3) - d = @constinferred svd_vals(a) - d′ = svd_vals(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 2) isa Ones{real(elt)} - @test arguments(d, 1) ≈ svd_vals(arguments(a, 1)) - end - - for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ Eye{elt}(3) - d = @constinferred svd_vals(a) - @test d == Ones(3) ⊗ Ones(3) - @test arguments(d, 1) isa Ones{real(elt)} - @test arguments(d, 2) isa Ones{real(elt)} - end - - # left_null - a = Eye(3) ⊗ randn(3, 3) - n = @constinferred left_null(a) - @test norm(n' * a) ≈ 0 - @test arguments(n, 1) isa Eye - - a = randn(3, 3) ⊗ Eye(3) - n = @constinferred left_null(a) - @test norm(n' * a) ≈ 0 - @test arguments(n, 2) isa Eye - - a = Eye(3) ⊗ Eye(3) - @test_throws MethodError left_null(a) - - # right_null - a = Eye(3) ⊗ randn(3, 3) - n = @constinferred right_null(a) - @test norm(a * n') ≈ 0 - @test arguments(n, 1) isa Eye - - a = randn(3, 3) ⊗ Eye(3) - n = @constinferred right_null(a) - @test norm(a * n') ≈ 0 - @test arguments(n, 2) isa Eye - - a = Eye(3) ⊗ Eye(3) - @test_throws MethodError right_null(a) -end