diff --git a/Project.toml b/Project.toml index 196b1f5..6562906 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.12" +version = "0.1.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 597872c..4552a2f 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -9,5 +9,6 @@ include("matrixalgebrakit.jl") include("fillarrays/kroneckerarray.jl") include("fillarrays/linearalgebra.jl") include("fillarrays/matrixalgebrakit.jl") +include("fillarrays/matrixalgebrakit_truncate.jl") end diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 1030602..8dcd431 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -11,6 +11,23 @@ const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T, # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a +# Like `similar` but preserves `Eye`. +function _similar(a::AbstractArray, elt::Type, ax::Tuple) + return similar(a, elt, ax) +end +function _similar(A::Type{<:AbstractArray}, ax::Tuple) + return similar(A, ax) +end +function _similar(a::AbstractArray, ax::Tuple) + return _similar(a, eltype(a), ax) +end +function _similar(a::AbstractArray, elt::Type) + return _similar(a, elt, axes(a)) +end +function _similar(a::AbstractArray) + return _similar(a, eltype(a), axes(a)) +end + # Like `similar` but preserves `Eye`. function _similar(a::Eye, elt::Type, axs::NTuple{2,AbstractUnitRange}) return Eye{elt}(axs) diff --git a/src/fillarrays/linearalgebra.jl b/src/fillarrays/linearalgebra.jl index ed2cd6a..d514467 100644 --- a/src/fillarrays/linearalgebra.jl +++ b/src/fillarrays/linearalgebra.jl @@ -7,6 +7,9 @@ end function _mul(a::Eye, b::Eye) check_mul_axes(a, b) + (size(a, 1) > size(a, 2)) && + (size(b, 1) < size(b, 2)) && + error("This multiplication leads to a projector.") T = promote_type(eltype(a), eltype(b)) return Eye{T}((axes(a, 1), axes(b, 2))) end diff --git a/src/fillarrays/matrixalgebrakit.jl b/src/fillarrays/matrixalgebrakit.jl index 1f82d24..093760b 100644 --- a/src/fillarrays/matrixalgebrakit.jl +++ b/src/fillarrays/matrixalgebrakit.jl @@ -1,17 +1,47 @@ -#################################################################################### -# Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and -# `A ⊗ Eye(n)` where `A`. -# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34 -# is merged. +function infimum(r1::AbstractRange, r2::AbstractUnitRange) + Base.require_one_based_indexing(r1, r2) + if length(r1) ≤ length(r2) + return r1 + else + return r2 + end +end +function supremum(r1::AbstractRange, r2::AbstractUnitRange) + Base.require_one_based_indexing(r1, r2) + if length(r1) ≥ length(r2) + return r1 + else + return r2 + end +end + +# Allow customization for `Eye`. +_diagview(a::Eye) = parent(a) -struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm +function _copy_input(f::F, a::Eye) where {F} + return a +end + +struct EyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm kwargs::KWargs end -SquareEyeAlgorithm(; kwargs...) = SquareEyeAlgorithm((; kwargs...)) +EyeAlgorithm(; kwargs...) = EyeAlgorithm((; kwargs...)) -# Defined to avoid type piracy. -_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a) -_copy_input_squareeye(f::F, a::SquareEye) where {F} = a +for f in [ + :default_eig_algorithm, + :default_eigh_algorithm, + :default_lq_algorithm, + :default_qr_algorithm, + :default_polar_algorithm, + :default_svd_algorithm, +] + _f = Symbol(:_, f) + @eval begin + function $_f(A::Type{<:Eye}; kwargs...) + return EyeAlgorithm(; kwargs...) + end + end +end for f in [ :eig_full, @@ -30,269 +60,107 @@ for f in [ :right_polar, :svd_compact, :svd_full, + :svd_vals, ] - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.copy_input(::typeof($f), a::$T) - return _copy_input_squareeye($f, a.a) ⊗ _copy_input_squareeye($f, a.b) - end - end - end -end - -for f in [ - :default_eig_algorithm, - :default_eigh_algorithm, - :default_lq_algorithm, - :default_qr_algorithm, - :default_polar_algorithm, - :default_svd_algorithm, -] - f′ = Symbol("_", f, "_squareeye") + f! = Symbol(f, "!") @eval begin - $f′(a; kwargs...) = $f(a; kwargs...) - $f′(a::Type{<:SquareEye}; kwargs...) = SquareEyeAlgorithm(; kwargs...) - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.$f(A::Type{<:$T}; kwargs1=(;), kwargs2=(;), kwargs...) - A1, A2 = argument_types(A) - return KroneckerAlgorithm( - $f′(A1; kwargs..., kwargs1...), $f′(A2; kwargs..., kwargs2...) - ) - end + function MatrixAlgebraKit.$f!(a::Eye, F, ::EyeAlgorithm) + return F end end end -# Defined to avoid type piracy. -_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a) -_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg) +_complex(a::AbstractArray) = complex(a) +_complex(a::Eye{<:Complex}) = a +_complex(a::Eye) = _similar(a, complex(eltype(a))) +_real(a::AbstractArray) = real(a) +_real(a::Eye{<:Real}) = a +_real(a::Eye) = _similar(a, real(eltype(a))) -for f in [:left_null!, :right_null!] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = a - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = a - end +# Implementations of `Eye` factorizations are doing in `initialize_output` +# so they can be used in KroneckerArray factorizations. +function _initialize_output(::typeof(eig_full!), a::Eye, ::EyeAlgorithm) + LinearAlgebra.checksquare(a) + return _complex.((a, a)) end -for f in [ - :qr_compact!, - :qr_full!, - :left_orth!, - :left_polar!, - :lq_compact!, - :lq_full!, - :right_orth!, - :right_polar!, -] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a) - end +function _initialize_output(::typeof(eigh_full!), a::Eye, ::EyeAlgorithm) + LinearAlgebra.checksquare(a) + return (_real(a), a) end -_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye) = complex.((a, a)) -_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye, alg) = complex.((a, a)) -_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye) = (real(a), a) -_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye, alg) = (real(a), a) -for f in [:svd_compact!, :svd_full!] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, real(a), a) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, real(a), a) - end -end - -for f in [ - :eig_full!, - :eigh_full!, - :qr_compact!, - :qr_full!, - :left_orth!, - :left_polar!, - :lq_compact!, - :lq_full!, - :right_orth!, - :right_polar!, - :svd_compact!, - :svd_full!, -] - f′ = Symbol("_", f, "_squareeye") - @eval begin - $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) - $f′(a, F, alg::SquareEyeAlgorithm) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) - return _initialize_output_squareeye($f, a.a) .⊗ - _initialize_output_squareeye($f, a.b) - end - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::$T, alg::KroneckerAlgorithm - ) - return _initialize_output_squareeye($f, a.a, alg.a) .⊗ - _initialize_output_squareeye($f, a.b, alg.b) - end - function MatrixAlgebraKit.$f( - a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - $f′(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...) - $f′(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...) - return F - end - end - end -end - -for f in [:left_null!, :right_null!] - f′ = Symbol("_", f, "_squareeye") - @eval begin - $f′(a, F; kwargs...) = $f(a, F; kwargs...) - $f′(a::SquareEye, F) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) - return _initialize_output_squareeye($f, a.a) ⊗ _initialize_output_squareeye($f, a.b) - end - function MatrixAlgebraKit.$f(a::$T, F; kwargs1=(;), kwargs2=(;), kwargs...) - $f′(a.a, F.a; kwargs..., kwargs1...) - $f′(a.b, F.b; kwargs..., kwargs2...) - return F - end - end - end +function _initialize_output(::typeof(eig_vals!), a::Eye, ::EyeAlgorithm) + LinearAlgebra.checksquare(a) + # TODO: Use `_diagview`/`_diag`. + return _complex(parent(a)) end - -function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::SquareEyeSquareEye) - return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) +function _initialize_output(::typeof(eigh_vals!), a::Eye, ::EyeAlgorithm) + LinearAlgebra.checksquare(a) + # TODO: Use `_diagview`/`_diag`. + return _real(parent(a)) end -function MatrixAlgebraKit.left_null!( - a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... -) - return throw(MethodError(left_null!, (a, F))) +function _initialize_output(::typeof(svd_compact!), a::Eye, ::EyeAlgorithm) + ax_s = (infimum(axes(a)...), infimum(reverse(axes(a))...)) + ax_u = (axes(a, 1), ax_s[2]) + ax_v = (ax_s[1], axes(a, 2)) + Tr = real(eltype(a)) + return (_similar(a, ax_u), _similar(a, Tr, ax_s), _similar(a, ax_v)) end - -function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::SquareEyeSquareEye) - return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) +function _initialize_output(::typeof(svd_full!), a::Eye, ::EyeAlgorithm) + ax_s = axes(a) + ax_u = (axes(a, 1), axes(a, 1)) + ax_v = (axes(a, 2), axes(a, 2)) + Tr = real(eltype(a)) + return (_similar(a, ax_u), _similar(a, Tr, ax_s), _similar(a, ax_v)) end -function MatrixAlgebraKit.right_null!( - a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... -) - return throw(MethodError(right_null!, (a, F))) +function _initialize_output(::typeof(svd_vals!), a::Eye, ::EyeAlgorithm) + # TODO: Use `_diagview`/`_diag`. + return _real(parent(a)) end -_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye) = parent(a) -_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye, alg) = parent(a) -for f in [:eigh_vals!, svd_vals!] +for f in [:left_polar!, :right_polar!, qr_compact!, lq_compact!] @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = real(parent(a)) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = real(parent(a)) + function _initialize_output(::typeof($f), a::Eye, ::EyeAlgorithm) + ax = infimum(axes(a)...) + ax_x = (axes(a, 1), ax) + ax_y = (ax, axes(a, 2)) + return (_similar(a, ax_x), _similar(a, ax_y)) + end end end -for f in [:eig_vals!, :eigh_vals!, :svd_vals!] - f′ = Symbol("_", f, "_squareeye") +for f in [qr_full!, lq_full!] @eval begin - $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) - $f′(a, F, alg::SquareEyeAlgorithm) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::$T, alg::KroneckerAlgorithm - ) - return _initialize_output_squareeye($f, a.a, alg.a) ⊗ - _initialize_output_squareeye($f, a.b, alg.b) - end - function MatrixAlgebraKit.$f( - a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - $f′(a.a, F.a, alg.a; kwargs..., kwargs1...) - $f′(a.b, F.b, alg.b; kwargs..., kwargs2...) - return F - end + function _initialize_output(::typeof($f), a::Eye, ::EyeAlgorithm) + ax = supremum(axes(a)...) + ax_x = (axes(a, 1), ax) + ax_y = (ax, axes(a, 2)) + return (_similar(a, ax_x), _similar(a, ax_y)) end end end -using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate! - -struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy - strategy::T -end - -# Avoid instantiating the identity. -function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2}) - return a.a ⊗ a.b[I[1].b, I[2].b] -end -function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2}) - return a.a[I[1].a, I[2].a] ⊗ a.b -end -function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2}) - return a -end - -using FillArrays: OnesVector -const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} -const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} -const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} - -function MatrixAlgebraKit.findtruncated( - values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.a, prods)) - # Drop truncations that occur within the identity. - I_data = filter(I_data) do i - return count(x -> x.a == i, prods) == length(values.a) - end - return (:) × I_data -end -function MatrixAlgebraKit.findtruncated( - values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.b, prods)) - # Drop truncations that occur within the identity. - I_data = filter(I_data) do i - return count(x -> x.b == i, prods) == length(values.b) +for f in [:left_orth!, :right_orth!] + @eval begin + function _initialize_output(::typeof($f), a::Eye) + ax = infimum(axes(a)...) + ax_x = (axes(a, 1), ax) + ax_y = (ax, axes(a, 2)) + return (_similar(a, ax_x), _similar(a, ax_y)) + end end - return I_data × (:) -end -function MatrixAlgebraKit.findtruncated( - values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy -) - return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) end -for f in [:eig_trunc!, :eigh_trunc!] +for f in [:left_null!, :right_null!] + _f = Symbol(:_, f) @eval begin - function MatrixAlgebraKit.truncate!( - ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy - ) - return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) + function _initialize_output(::typeof($f), a::Eye) + return a end - function MatrixAlgebraKit.truncate!( - ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy - ) - I = findtruncated(diagview(D), strategy) - return (D[I, I], V[(:) × (:), I]) + function $_f(a::Eye, F) + return F end - end -end -function MatrixAlgebraKit.truncate!( - f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy -) - return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) -end -function MatrixAlgebraKit.truncate!( - ::typeof(svd_trunc!), - (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, - strategy::KroneckerTruncationStrategy, -) - I = findtruncated(diagview(S), strategy) - return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) + function MatrixAlgebraKit.$f(a::EyeEye, F; kwargs...) + return throw(MethodError($f, (a, F))) + end + end end diff --git a/src/fillarrays/matrixalgebrakit_truncate.jl b/src/fillarrays/matrixalgebrakit_truncate.jl new file mode 100644 index 0000000..ae50f27 --- /dev/null +++ b/src/fillarrays/matrixalgebrakit_truncate.jl @@ -0,0 +1,81 @@ +using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate! + +struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy + strategy::T +end + +# Avoid instantiating the identity. +function Base.getindex(a::EyeKronecker, I::Vararg{CartesianProduct{Colon},2}) + return a.a ⊗ a.b[I[1].b, I[2].b] +end +function Base.getindex(a::KroneckerEye, I::Vararg{CartesianProduct{<:Any,Colon},2}) + return a.a[I[1].a, I[2].a] ⊗ a.b +end +function Base.getindex(a::EyeEye, I::Vararg{CartesianProduct{Colon,Colon},2}) + return a +end + +using FillArrays: OnesVector +const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} +const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} +const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} + +function MatrixAlgebraKit.findtruncated( + values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy +) + I = findtruncated(Vector(values), strategy.strategy) + prods = collect(only(axes(values)).product)[I] + I_data = unique(map(x -> x.a, prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> x.a == i, prods) == length(values.a) + end + return (:) × I_data +end +function MatrixAlgebraKit.findtruncated( + values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy +) + I = findtruncated(Vector(values), strategy.strategy) + prods = collect(only(axes(values)).product)[I] + I_data = unique(map(x -> x.b, prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> x.b == i, prods) == length(values.b) + end + return I_data × (:) +end +function MatrixAlgebraKit.findtruncated( + values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy +) + return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) +end + +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy + ) + return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) + end + function MatrixAlgebraKit.truncate!( + ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy + ) + I = findtruncated(diagview(D), strategy) + return (D[I, I], V[(:) × (:), I]) + end + end +end + +function MatrixAlgebraKit.truncate!( + f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy +) + return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) +end +function MatrixAlgebraKit.truncate!( + ::typeof(svd_trunc!), + (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, + strategy::KroneckerTruncationStrategy, +) + I = findtruncated(diagview(S), strategy) + return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) +end diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index 5cf8b77..7e88608 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -32,8 +32,10 @@ using MatrixAlgebraKit: truncate! using MatrixAlgebraKit: MatrixAlgebraKit, diagview +# Allow customization for `Eye`. +_diagview(a::AbstractMatrix) = diagview(a) function MatrixAlgebraKit.diagview(a::KroneckerMatrix) - return diagview(a.a) ⊗ diagview(a.b) + return _diagview(a.a) ⊗ _diagview(a.b) end struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm @@ -60,6 +62,10 @@ using MatrixAlgebraKit: svd_compact, svd_full +function _copy_input(f::F, a::AbstractMatrix) where {F} + return copy_input(f, a) +end + for f in [ :eig_full, :eigh_full, @@ -74,7 +80,7 @@ for f in [ ] @eval begin function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix) - return copy_input($f, a.a) ⊗ copy_input($f, a.b) + return _copy_input($f, a.a) ⊗ _copy_input($f, a.b) end end end @@ -87,13 +93,17 @@ for f in [ :default_polar_algorithm, :default_svd_algorithm, ] + _f = Symbol(:_, f) @eval begin + function $_f(A::Type{<:AbstractMatrix}; kwargs...) + return $f(A; kwargs...) + end function MatrixAlgebraKit.$f( A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs... ) A1, A2 = argument_types(A) return KroneckerAlgorithm( - $f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...) + $_f(A1; kwargs..., kwargs1...), $_f(A2; kwargs..., kwargs2...) ) end end @@ -112,6 +122,12 @@ function MatrixAlgebraKit.default_algorithm( return default_qr_algorithm(A; kwargs...) end +# Allows overloading while avoiding type piracy. +function _initialize_output(f::F, a::AbstractMatrix, alg::AbstractAlgorithm) where {F} + return initialize_output(f, a, alg) +end +_initialize_output(f::F, a::AbstractMatrix) where {F} = initialize_output(f, a) + for f in [ :eig_full!, :eigh_full!, @@ -128,7 +144,7 @@ for f in [ function MatrixAlgebraKit.initialize_output( ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm ) - return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b) + return _initialize_output($f, a.a, alg.a) .⊗ _initialize_output($f, a.b, alg.b) end function MatrixAlgebraKit.$f( a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... @@ -145,7 +161,7 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!] function MatrixAlgebraKit.initialize_output( ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm ) - return initialize_output($f, a.a, alg.a) ⊗ initialize_output($f, a.b, alg.b) + return _initialize_output($f, a.a, alg.a) ⊗ _initialize_output($f, a.b, alg.b) end function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) $f(a.a, F.a, alg.a) @@ -158,19 +174,23 @@ end for f in [:left_orth!, :right_orth!] @eval begin function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) - return initialize_output($f, a.a) .⊗ initialize_output($f, a.b) + return _initialize_output($f, a.a) .⊗ _initialize_output($f, a.b) end end end for f in [:left_null!, :right_null!] + _f = Symbol(:_, f) @eval begin function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) - return initialize_output($f, a.a) ⊗ initialize_output($f, a.b) + return _initialize_output($f, a.a) ⊗ _initialize_output($f, a.b) + end + function $_f(a::AbstractMatrix, F; kwargs...) + return $f(a, F; kwargs...) end function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...) - $f(a.a, F.a; kwargs..., kwargs1...) - $f(a.b, F.b; kwargs..., kwargs2...) + $_f(a.a, F.a; kwargs..., kwargs1...) + $_f(a.b, F.b; kwargs..., kwargs2...) return F end end diff --git a/test/test_fillarrays_matrixalgebrakit.jl b/test/test_fillarrays_matrixalgebrakit.jl index bbc08d8..d785bd6 100644 --- a/test/test_fillarrays_matrixalgebrakit.jl +++ b/test/test_fillarrays_matrixalgebrakit.jl @@ -29,19 +29,19 @@ herm(a) = parent(hermitianpart(a)) @testset "MatrixAlgebraKit + Eye" begin for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + a = Eye{elt}(3, 3) ⊗ randn(elt, 3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d @test arguments(d, 1) isa Eye{complex(elt)} @test arguments(v, 1) isa Eye{complex(elt)} - a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d @test arguments(d, 2) isa Eye{complex(elt)} @test arguments(v, 2) isa Eye{complex(elt)} - a = Eye{elt}(3) ⊗ Eye{elt}(3) + a = Eye{elt}(3, 3) ⊗ Eye{elt}(3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d @test arguments(d, 1) isa Eye{complex(elt)} @@ -51,20 +51,20 @@ herm(a) = parent(hermitianpart(a)) end for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) - d, v = @constinferred eigh_full(a) + a = Eye{elt}(3, 3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) + d, v = @constinferred eigh_full($a) @test a * v ≈ v * d @test arguments(d, 1) isa Eye{real(elt)} @test arguments(v, 1) isa Eye{elt} - a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) - d, v = @constinferred eigh_full(a) + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3, 3) + d, v = @constinferred eigh_full($a) @test a * v ≈ v * d @test arguments(d, 2) isa Eye{real(elt)} @test arguments(v, 2) isa Eye{elt} - a = Eye{elt}(3) ⊗ Eye{elt}(3) - d, v = @constinferred eigh_full(a) + a = Eye{elt}(3, 3) ⊗ Eye{elt}(3, 3) + d, v = @constinferred eigh_full($a) @test a * v ≈ v * d @test arguments(d, 1) isa Eye{real(elt)} @test arguments(d, 2) isa Eye{real(elt)} @@ -116,22 +116,22 @@ herm(a) = parent(hermitianpart(a)) end for f in ( - left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar + left_orth, right_orth, left_polar, right_polar, qr_compact, lq_compact, qr_full, lq_full ) - a = Eye(3) ⊗ randn(3, 3) - x, y = @constinferred f(a) + a = Eye(3, 3) ⊗ randn(3, 3) + x, y = @constinferred f($a) @test x * y ≈ a @test arguments(x, 1) isa Eye @test arguments(y, 1) isa Eye - a = randn(3, 3) ⊗ Eye(3) - x, y = @constinferred f(a) + a = randn(3, 3) ⊗ Eye(3, 3) + x, y = @constinferred f($a) @test x * y ≈ a @test arguments(x, 2) isa Eye @test arguments(y, 2) isa Eye - a = Eye(3) ⊗ Eye(3) - x, y = f(a) + a = Eye(3, 3) ⊗ Eye(3, 3) + x, y = @constinferred f($a) @test x * y ≈ a @test arguments(x, 1) isa Eye @test arguments(y, 1) isa Eye @@ -141,8 +141,8 @@ herm(a) = parent(hermitianpart(a)) for f in (svd_compact, svd_full) for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) - u, s, v = @constinferred f(a) + a = Eye{elt}(3, 3) ⊗ randn(elt, 3, 3) + u, s, v = @constinferred f($a) @test u * s * v ≈ a @test eltype(u) === elt @test eltype(s) === real(elt) @@ -151,8 +151,8 @@ herm(a) = parent(hermitianpart(a)) @test arguments(s, 1) isa Eye{real(elt)} @test arguments(v, 1) isa Eye{elt} - a = randn(elt, 3, 3) ⊗ Eye{elt}(3) - u, s, v = @constinferred f(a) + a = randn(elt, 3, 3) ⊗ Eye{elt}(3, 3) + u, s, v = @constinferred f($a) @test u * s * v ≈ a @test eltype(u) === elt @test eltype(s) === real(elt) @@ -161,8 +161,8 @@ herm(a) = parent(hermitianpart(a)) @test arguments(s, 2) isa Eye{real(elt)} @test arguments(v, 2) isa Eye{elt} - a = Eye{elt}(3) ⊗ Eye{elt}(3) - u, s, v = @constinferred f(a) + a = Eye{elt}(3, 3) ⊗ Eye{elt}(3, 3) + u, s, v = @constinferred f($a) @test u * s * v ≈ a @test eltype(u) === elt @test eltype(s) === real(elt) @@ -178,7 +178,7 @@ herm(a) = parent(hermitianpart(a)) # svd_trunc for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + a = Eye{elt}(3, 3) ⊗ randn(elt, 3, 3) # TODO: Type inference is broken for `svd_trunc`, # look into fixing it. # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) @@ -197,7 +197,7 @@ herm(a) = parent(hermitianpart(a)) end for elt in (Float32, ComplexF32) - a = randn(elt, 3, 3) ⊗ Eye{elt}(3) + a = randn(elt, 3, 3) ⊗ Eye{elt}(3, 3) # TODO: Type inference is broken for `svd_trunc`, # look into fixing it. # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) @@ -215,12 +215,12 @@ herm(a) = parent(hermitianpart(a)) @test size(v) == (6, 9) end - a = Eye(3) ⊗ Eye(3) + a = Eye(3, 3) ⊗ Eye(3, 3) @test_throws ArgumentError svd_trunc(a) # svd_vals for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + a = Eye{elt}(3, 3) ⊗ randn(elt, 3, 3) d = @constinferred svd_vals(a) d′ = svd_vals(Matrix(a)) @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) @@ -246,12 +246,12 @@ herm(a) = parent(hermitianpart(a)) end # left_null - a = Eye(3) ⊗ randn(3, 3) + a = Eye(3, 3) ⊗ randn(3, 3) n = @constinferred left_null(a) @test norm(n' * a) ≈ 0 @test arguments(n, 1) isa Eye - a = randn(3, 3) ⊗ Eye(3) + a = randn(3, 3) ⊗ Eye(3, 3) n = @constinferred left_null(a) @test norm(n' * a) ≈ 0 @test arguments(n, 2) isa Eye