From e392cc8b449436a3c40c926d9ebc81d5bbb5fe6c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 15 Jun 2025 17:48:30 -0400 Subject: [PATCH 1/5] More factorizations --- src/fillarrays/kroneckerarray.jl | 17 ++ src/fillarrays/linearalgebra.jl | 3 + src/fillarrays/matrixalgebrakit.jl | 345 +++++++---------------- src/matrixalgebrakit.jl | 26 +- test/test_fillarrays_matrixalgebrakit.jl | 202 ++++++------- 5 files changed, 238 insertions(+), 355 deletions(-) diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 1030602..8dcd431 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -11,6 +11,23 @@ const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T, # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a +# Like `similar` but preserves `Eye`. +function _similar(a::AbstractArray, elt::Type, ax::Tuple) + return similar(a, elt, ax) +end +function _similar(A::Type{<:AbstractArray}, ax::Tuple) + return similar(A, ax) +end +function _similar(a::AbstractArray, ax::Tuple) + return _similar(a, eltype(a), ax) +end +function _similar(a::AbstractArray, elt::Type) + return _similar(a, elt, axes(a)) +end +function _similar(a::AbstractArray) + return _similar(a, eltype(a), axes(a)) +end + # Like `similar` but preserves `Eye`. function _similar(a::Eye, elt::Type, axs::NTuple{2,AbstractUnitRange}) return Eye{elt}(axs) diff --git a/src/fillarrays/linearalgebra.jl b/src/fillarrays/linearalgebra.jl index ed2cd6a..d514467 100644 --- a/src/fillarrays/linearalgebra.jl +++ b/src/fillarrays/linearalgebra.jl @@ -7,6 +7,9 @@ end function _mul(a::Eye, b::Eye) check_mul_axes(a, b) + (size(a, 1) > size(a, 2)) && + (size(b, 1) < size(b, 2)) && + error("This multiplication leads to a projector.") T = promote_type(eltype(a), eltype(b)) return Eye{T}((axes(a, 1), axes(b, 2))) end diff --git a/src/fillarrays/matrixalgebrakit.jl b/src/fillarrays/matrixalgebrakit.jl index 1f82d24..7a1f138 100644 --- a/src/fillarrays/matrixalgebrakit.jl +++ b/src/fillarrays/matrixalgebrakit.jl @@ -1,17 +1,44 @@ -#################################################################################### -# Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and -# `A ⊗ Eye(n)` where `A`. -# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34 -# is merged. +function infimum(r1::AbstractRange, r2::AbstractUnitRange) + Base.require_one_based_indexing(r1, r2) + if length(r1) ≤ length(r2) + return r1 + else + return r2 + end +end +function supremum(r1::AbstractRange, r2::AbstractUnitRange) + Base.require_one_based_indexing(r1, r2) + if length(r1) ≥ length(r2) + return r1 + else + return r2 + end +end + +function _copy_input(f::F, a::Eye) where {F} + return a +end -struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm +struct EyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm kwargs::KWargs end -SquareEyeAlgorithm(; kwargs...) = SquareEyeAlgorithm((; kwargs...)) +EyeAlgorithm(; kwargs...) = EyeAlgorithm((; kwargs...)) -# Defined to avoid type piracy. -_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a) -_copy_input_squareeye(f::F, a::SquareEye) where {F} = a +for f in [ + :default_eig_algorithm, + :default_eigh_algorithm, + :default_lq_algorithm, + :default_qr_algorithm, + :default_polar_algorithm, + :default_svd_algorithm, +] + _f = Symbol(:_, f) + @eval begin + function $_f(A::Type{<:Eye}; kwargs...) + return EyeAlgorithm(; kwargs...) + end + end +end for f in [ :eig_full, @@ -30,269 +57,91 @@ for f in [ :right_polar, :svd_compact, :svd_full, + :svd_vals, ] - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.copy_input(::typeof($f), a::$T) - return _copy_input_squareeye($f, a.a) ⊗ _copy_input_squareeye($f, a.b) - end - end - end -end - -for f in [ - :default_eig_algorithm, - :default_eigh_algorithm, - :default_lq_algorithm, - :default_qr_algorithm, - :default_polar_algorithm, - :default_svd_algorithm, -] - f′ = Symbol("_", f, "_squareeye") + f! = Symbol(f, "!") @eval begin - $f′(a; kwargs...) = $f(a; kwargs...) - $f′(a::Type{<:SquareEye}; kwargs...) = SquareEyeAlgorithm(; kwargs...) - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.$f(A::Type{<:$T}; kwargs1=(;), kwargs2=(;), kwargs...) - A1, A2 = argument_types(A) - return KroneckerAlgorithm( - $f′(A1; kwargs..., kwargs1...), $f′(A2; kwargs..., kwargs2...) - ) - end + function MatrixAlgebraKit.$f!(a::Eye, F, ::EyeAlgorithm) + return F end end end -# Defined to avoid type piracy. -_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a) -_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg) +_complex(a::AbstractArray) = complex(a) +_complex(a::Eye{<:Complex}) = a +_complex(a::Eye) = _similar(a, complex(eltype(a))) +_real(a::AbstractArray) = real(a) +_real(a::Eye{<:Real}) = a +_real(a::Eye) = _similar(a, real(eltype(a))) -for f in [:left_null!, :right_null!] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = a - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = a - end +# Implementations of `Eye` factorizations are doing in `initialize_output` +# so they can be used in KroneckerArray factorizations. +function _initialize_output(::typeof(eig_full!), a::Eye, ::EyeAlgorithm) + LinearAlgebra.checksquare(a) + return _complex.((a, a)) end -for f in [ - :qr_compact!, - :qr_full!, - :left_orth!, - :left_polar!, - :lq_compact!, - :lq_full!, - :right_orth!, - :right_polar!, -] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a) - end +function _initialize_output(::typeof(eigh_full!), a::Eye, ::EyeAlgorithm) + LinearAlgebra.checksquare(a) + return (_real(a), a) end -_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye) = complex.((a, a)) -_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye, alg) = complex.((a, a)) -_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye) = (real(a), a) -_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye, alg) = (real(a), a) -for f in [:svd_compact!, :svd_full!] - @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, real(a), a) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, real(a), a) - end +function _initialize_output(::typeof(eig_vals!), a::Eye, ::EyeAlgorithm) + LinearAlgebra.checksquare(a) + # TODO: Use `_diagview`/`_diag`. + return _complex(parent(a)) end - -for f in [ - :eig_full!, - :eigh_full!, - :qr_compact!, - :qr_full!, - :left_orth!, - :left_polar!, - :lq_compact!, - :lq_full!, - :right_orth!, - :right_polar!, - :svd_compact!, - :svd_full!, -] - f′ = Symbol("_", f, "_squareeye") - @eval begin - $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) - $f′(a, F, alg::SquareEyeAlgorithm) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) - return _initialize_output_squareeye($f, a.a) .⊗ - _initialize_output_squareeye($f, a.b) - end - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::$T, alg::KroneckerAlgorithm - ) - return _initialize_output_squareeye($f, a.a, alg.a) .⊗ - _initialize_output_squareeye($f, a.b, alg.b) - end - function MatrixAlgebraKit.$f( - a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - $f′(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...) - $f′(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...) - return F - end - end - end +function _initialize_output(::typeof(eigh_vals!), a::Eye, ::EyeAlgorithm) + LinearAlgebra.checksquare(a) + # TODO: Use `_diagview`/`_diag`. + return _real(parent(a)) end - -for f in [:left_null!, :right_null!] - f′ = Symbol("_", f, "_squareeye") - @eval begin - $f′(a, F; kwargs...) = $f(a, F; kwargs...) - $f′(a::SquareEye, F) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) - return _initialize_output_squareeye($f, a.a) ⊗ _initialize_output_squareeye($f, a.b) - end - function MatrixAlgebraKit.$f(a::$T, F; kwargs1=(;), kwargs2=(;), kwargs...) - $f′(a.a, F.a; kwargs..., kwargs1...) - $f′(a.b, F.b; kwargs..., kwargs2...) - return F - end - end - end +function _initialize_output(::typeof(svd_compact!), a::Eye, ::EyeAlgorithm) + ax_s = (infimum(axes(a)...), infimum(reverse(axes(a))...)) + ax_u = (axes(a, 1), ax_s[2]) + ax_v = (ax_s[1], axes(a, 2)) + Tr = real(eltype(a)) + return (_similar(a, ax_u), _similar(a, Tr, ax_s), _similar(a, ax_v)) end - -function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::SquareEyeSquareEye) - return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) +function _initialize_output(::typeof(svd_full!), a::Eye, ::EyeAlgorithm) + ax_s = axes(a) + ax_u = (axes(a, 1), axes(a, 1)) + ax_v = (axes(a, 2), axes(a, 2)) + Tr = real(eltype(a)) + return (_similar(a, ax_u), _similar(a, Tr, ax_s), _similar(a, ax_v)) end -function MatrixAlgebraKit.left_null!( - a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... -) - return throw(MethodError(left_null!, (a, F))) +function _initialize_output(::typeof(svd_vals!), a::Eye, ::EyeAlgorithm) + # TODO: Use `_diagview`/`_diag`. + return _real(parent(a)) end -function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::SquareEyeSquareEye) - return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) -end -function MatrixAlgebraKit.right_null!( - a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... -) - return throw(MethodError(right_null!, (a, F))) -end - -_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye) = parent(a) -_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye, alg) = parent(a) -for f in [:eigh_vals!, svd_vals!] +for f in [:left_polar!, :right_polar!, qr_compact!, lq_compact!] @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = real(parent(a)) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = real(parent(a)) + function _initialize_output(::typeof($f), a::Eye, ::EyeAlgorithm) + ax = infimum(axes(a)...) + ax_x = (axes(a, 1), ax) + ax_y = (ax, axes(a, 2)) + return (_similar(a, ax_x), _similar(a, ax_y)) + end end end -for f in [:eig_vals!, :eigh_vals!, :svd_vals!] - f′ = Symbol("_", f, "_squareeye") +for f in [qr_full!, lq_full!] @eval begin - $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) - $f′(a, F, alg::SquareEyeAlgorithm) = F - end - for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] - @eval begin - function MatrixAlgebraKit.initialize_output( - ::typeof($f), a::$T, alg::KroneckerAlgorithm - ) - return _initialize_output_squareeye($f, a.a, alg.a) ⊗ - _initialize_output_squareeye($f, a.b, alg.b) - end - function MatrixAlgebraKit.$f( - a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... - ) - $f′(a.a, F.a, alg.a; kwargs..., kwargs1...) - $f′(a.b, F.b, alg.b; kwargs..., kwargs2...) - return F - end + function _initialize_output(::typeof($f), a::Eye, ::EyeAlgorithm) + ax = supremum(axes(a)...) + ax_x = (axes(a, 1), ax) + ax_y = (ax, axes(a, 2)) + return (_similar(a, ax_x), _similar(a, ax_y)) end end end -using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate! - -struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy - strategy::T -end - -# Avoid instantiating the identity. -function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2}) - return a.a ⊗ a.b[I[1].b, I[2].b] -end -function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2}) - return a.a[I[1].a, I[2].a] ⊗ a.b -end -function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2}) - return a -end - -using FillArrays: OnesVector -const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} -const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} -const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} - -function MatrixAlgebraKit.findtruncated( - values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.a, prods)) - # Drop truncations that occur within the identity. - I_data = filter(I_data) do i - return count(x -> x.a == i, prods) == length(values.a) - end - return (:) × I_data -end -function MatrixAlgebraKit.findtruncated( - values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.b, prods)) - # Drop truncations that occur within the identity. - I_data = filter(I_data) do i - return count(x -> x.b == i, prods) == length(values.b) - end - return I_data × (:) -end -function MatrixAlgebraKit.findtruncated( - values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy -) - return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) -end - -for f in [:eig_trunc!, :eigh_trunc!] +for f in [:left_orth!, :right_orth!] @eval begin - function MatrixAlgebraKit.truncate!( - ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy - ) - return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) - end - function MatrixAlgebraKit.truncate!( - ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy - ) - I = findtruncated(diagview(D), strategy) - return (D[I, I], V[(:) × (:), I]) + function _initialize_output(::typeof($f), a::Eye) + ax = infimum(axes(a)...) + ax_x = (axes(a, 1), ax) + ax_y = (ax, axes(a, 2)) + return (_similar(a, ax_x), _similar(a, ax_y)) end end end - -function MatrixAlgebraKit.truncate!( - f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy -) - return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) -end -function MatrixAlgebraKit.truncate!( - ::typeof(svd_trunc!), - (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, - strategy::KroneckerTruncationStrategy, -) - I = findtruncated(diagview(S), strategy) - return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) -end diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index 5cf8b77..1e3935d 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -60,6 +60,10 @@ using MatrixAlgebraKit: svd_compact, svd_full +function _copy_input(f::F, a::AbstractMatrix) where {F} + return copy_input(f, a) +end + for f in [ :eig_full, :eigh_full, @@ -74,7 +78,7 @@ for f in [ ] @eval begin function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix) - return copy_input($f, a.a) ⊗ copy_input($f, a.b) + return _copy_input($f, a.a) ⊗ _copy_input($f, a.b) end end end @@ -87,13 +91,17 @@ for f in [ :default_polar_algorithm, :default_svd_algorithm, ] + _f = Symbol(:_, f) @eval begin + function $_f(A::Type{<:AbstractMatrix}; kwargs...) + return $f(A; kwargs...) + end function MatrixAlgebraKit.$f( A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs... ) A1, A2 = argument_types(A) return KroneckerAlgorithm( - $f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...) + $_f(A1; kwargs..., kwargs1...), $_f(A2; kwargs..., kwargs2...) ) end end @@ -112,6 +120,12 @@ function MatrixAlgebraKit.default_algorithm( return default_qr_algorithm(A; kwargs...) end +# Allows overloading while avoiding type piracy. +function _initialize_output(f::F, a::AbstractMatrix, alg::AbstractAlgorithm) where {F} + return initialize_output(f, a, alg) +end +_initialize_output(f::F, a::AbstractMatrix) where {F} = initialize_output(f, a) + for f in [ :eig_full!, :eigh_full!, @@ -128,7 +142,7 @@ for f in [ function MatrixAlgebraKit.initialize_output( ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm ) - return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b) + return _initialize_output($f, a.a, alg.a) .⊗ _initialize_output($f, a.b, alg.b) end function MatrixAlgebraKit.$f( a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... @@ -145,7 +159,7 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!] function MatrixAlgebraKit.initialize_output( ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm ) - return initialize_output($f, a.a, alg.a) ⊗ initialize_output($f, a.b, alg.b) + return _initialize_output($f, a.a, alg.a) ⊗ _initialize_output($f, a.b, alg.b) end function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) $f(a.a, F.a, alg.a) @@ -158,7 +172,7 @@ end for f in [:left_orth!, :right_orth!] @eval begin function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) - return initialize_output($f, a.a) .⊗ initialize_output($f, a.b) + return _initialize_output($f, a.a) .⊗ _initialize_output($f, a.b) end end end @@ -166,7 +180,7 @@ end for f in [:left_null!, :right_null!] @eval begin function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) - return initialize_output($f, a.a) ⊗ initialize_output($f, a.b) + return _initialize_output($f, a.a) ⊗ _initialize_output($f, a.b) end function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...) $f(a.a, F.a; kwargs..., kwargs1...) diff --git a/test/test_fillarrays_matrixalgebrakit.jl b/test/test_fillarrays_matrixalgebrakit.jl index bbc08d8..355c36d 100644 --- a/test/test_fillarrays_matrixalgebrakit.jl +++ b/test/test_fillarrays_matrixalgebrakit.jl @@ -29,19 +29,19 @@ herm(a) = parent(hermitianpart(a)) @testset "MatrixAlgebraKit + Eye" begin for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + a = Eye{elt}(3, 3) ⊗ randn(elt, 3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d @test arguments(d, 1) isa Eye{complex(elt)} @test arguments(v, 1) isa Eye{complex(elt)} - a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d @test arguments(d, 2) isa Eye{complex(elt)} @test arguments(v, 2) isa Eye{complex(elt)} - a = Eye{elt}(3) ⊗ Eye{elt}(3) + a = Eye{elt}(3, 3) ⊗ Eye{elt}(3, 3) d, v = @constinferred eig_full(a) @test a * v ≈ v * d @test arguments(d, 1) isa Eye{complex(elt)} @@ -51,20 +51,20 @@ herm(a) = parent(hermitianpart(a)) end for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) - d, v = @constinferred eigh_full(a) + a = Eye{elt}(3, 3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) + d, v = @constinferred eigh_full($a) @test a * v ≈ v * d @test arguments(d, 1) isa Eye{real(elt)} @test arguments(v, 1) isa Eye{elt} - a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) - d, v = @constinferred eigh_full(a) + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3, 3) + d, v = @constinferred eigh_full($a) @test a * v ≈ v * d @test arguments(d, 2) isa Eye{real(elt)} @test arguments(v, 2) isa Eye{elt} - a = Eye{elt}(3) ⊗ Eye{elt}(3) - d, v = @constinferred eigh_full(a) + a = Eye{elt}(3, 3) ⊗ Eye{elt}(3, 3) + d, v = @constinferred eigh_full($a) @test a * v ≈ v * d @test arguments(d, 1) isa Eye{real(elt)} @test arguments(d, 2) isa Eye{real(elt)} @@ -72,26 +72,26 @@ herm(a) = parent(hermitianpart(a)) @test arguments(v, 2) isa Eye{elt} end - for f in (eig_trunc, eigh_trunc) - a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) - d, v = f(a; trunc=(; maxrank=7)) - @test a * v ≈ v * d - @test arguments(d, 1) isa Eye - @test arguments(v, 1) isa Eye - @test size(d) == (6, 6) - @test size(v) == (9, 6) + ## for f in (eig_trunc, eigh_trunc) + ## a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) + ## d, v = f(a; trunc=(; maxrank=7)) + ## @test a * v ≈ v * d + ## @test arguments(d, 1) isa Eye + ## @test arguments(v, 1) isa Eye + ## @test size(d) == (6, 6) + ## @test size(v) == (9, 6) - a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) - d, v = f(a; trunc=(; maxrank=7)) - @test a * v ≈ v * d - @test arguments(d, 2) isa Eye - @test arguments(v, 2) isa Eye - @test size(d) == (6, 6) - @test size(v) == (9, 6) + ## a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) + ## d, v = f(a; trunc=(; maxrank=7)) + ## @test a * v ≈ v * d + ## @test arguments(d, 2) isa Eye + ## @test arguments(v, 2) isa Eye + ## @test size(d) == (6, 6) + ## @test size(v) == (9, 6) - a = Eye(3) ⊗ Eye(3) - @test_throws ArgumentError f(a) - end + ## a = Eye(3) ⊗ Eye(3) + ## @test_throws ArgumentError f(a) + ## end for f in (eig_vals, eigh_vals) a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) @@ -116,22 +116,22 @@ herm(a) = parent(hermitianpart(a)) end for f in ( - left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar + left_orth, right_orth, left_polar, right_polar, qr_compact, lq_compact, qr_full, lq_full ) - a = Eye(3) ⊗ randn(3, 3) - x, y = @constinferred f(a) + a = Eye(3, 3) ⊗ randn(3, 3) + x, y = @constinferred f($a) @test x * y ≈ a @test arguments(x, 1) isa Eye @test arguments(y, 1) isa Eye - a = randn(3, 3) ⊗ Eye(3) - x, y = @constinferred f(a) + a = randn(3, 3) ⊗ Eye(3, 3) + x, y = @constinferred f($a) @test x * y ≈ a @test arguments(x, 2) isa Eye @test arguments(y, 2) isa Eye - a = Eye(3) ⊗ Eye(3) - x, y = f(a) + a = Eye(3, 3) ⊗ Eye(3, 3) + x, y = @constinferred f($a) @test x * y ≈ a @test arguments(x, 1) isa Eye @test arguments(y, 1) isa Eye @@ -141,8 +141,8 @@ herm(a) = parent(hermitianpart(a)) for f in (svd_compact, svd_full) for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) - u, s, v = @constinferred f(a) + a = Eye{elt}(3, 3) ⊗ randn(elt, 3, 3) + u, s, v = @constinferred f($a) @test u * s * v ≈ a @test eltype(u) === elt @test eltype(s) === real(elt) @@ -151,8 +151,8 @@ herm(a) = parent(hermitianpart(a)) @test arguments(s, 1) isa Eye{real(elt)} @test arguments(v, 1) isa Eye{elt} - a = randn(elt, 3, 3) ⊗ Eye{elt}(3) - u, s, v = @constinferred f(a) + a = randn(elt, 3, 3) ⊗ Eye{elt}(3, 3) + u, s, v = @constinferred f($a) @test u * s * v ≈ a @test eltype(u) === elt @test eltype(s) === real(elt) @@ -161,8 +161,8 @@ herm(a) = parent(hermitianpart(a)) @test arguments(s, 2) isa Eye{real(elt)} @test arguments(v, 2) isa Eye{elt} - a = Eye{elt}(3) ⊗ Eye{elt}(3) - u, s, v = @constinferred f(a) + a = Eye{elt}(3, 3) ⊗ Eye{elt}(3, 3) + u, s, v = @constinferred f($a) @test u * s * v ≈ a @test eltype(u) === elt @test eltype(s) === real(elt) @@ -176,47 +176,47 @@ herm(a) = parent(hermitianpart(a)) end end - # svd_trunc - for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) - # TODO: Type inference is broken for `svd_trunc`, - # look into fixing it. - # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 1) isa Eye{elt} - @test arguments(s, 1) isa Eye{real(elt)} - @test arguments(v, 1) isa Eye{elt} - @test size(u) == (9, 6) - @test size(s) == (6, 6) - @test size(v) == (6, 9) - end + ## # svd_trunc + ## for elt in (Float32, ComplexF32) + ## a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + ## # TODO: Type inference is broken for `svd_trunc`, + ## # look into fixing it. + ## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + ## u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + ## @test eltype(u) === elt + ## @test eltype(s) === real(elt) + ## @test eltype(v) === elt + ## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + ## @test Matrix(u * s * v) ≈ u′ * s′ * v′ + ## @test arguments(u, 1) isa Eye{elt} + ## @test arguments(s, 1) isa Eye{real(elt)} + ## @test arguments(v, 1) isa Eye{elt} + ## @test size(u) == (9, 6) + ## @test size(s) == (6, 6) + ## @test size(v) == (6, 9) + ## end - for elt in (Float32, ComplexF32) - a = randn(elt, 3, 3) ⊗ Eye{elt}(3) - # TODO: Type inference is broken for `svd_trunc`, - # look into fixing it. - # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - @test eltype(u) === elt - @test eltype(s) === real(elt) - @test eltype(v) === elt - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 2) isa Eye{elt} - @test arguments(s, 2) isa Eye{real(elt)} - @test arguments(v, 2) isa Eye{elt} - @test size(u) == (9, 6) - @test size(s) == (6, 6) - @test size(v) == (6, 9) - end + ## for elt in (Float32, ComplexF32) + ## a = randn(elt, 3, 3) ⊗ Eye{elt}(3) + ## # TODO: Type inference is broken for `svd_trunc`, + ## # look into fixing it. + ## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + ## u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + ## @test eltype(u) === elt + ## @test eltype(s) === real(elt) + ## @test eltype(v) === elt + ## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + ## @test Matrix(u * s * v) ≈ u′ * s′ * v′ + ## @test arguments(u, 2) isa Eye{elt} + ## @test arguments(s, 2) isa Eye{real(elt)} + ## @test arguments(v, 2) isa Eye{elt} + ## @test size(u) == (9, 6) + ## @test size(s) == (6, 6) + ## @test size(v) == (6, 9) + ## end - a = Eye(3) ⊗ Eye(3) - @test_throws ArgumentError svd_trunc(a) + ## a = Eye(3) ⊗ Eye(3) + ## @test_throws ArgumentError svd_trunc(a) # svd_vals for elt in (Float32, ComplexF32) @@ -245,31 +245,31 @@ herm(a) = parent(hermitianpart(a)) @test arguments(d, 2) isa Ones{real(elt)} end - # left_null - a = Eye(3) ⊗ randn(3, 3) - n = @constinferred left_null(a) - @test norm(n' * a) ≈ 0 - @test arguments(n, 1) isa Eye + ## # left_null + ## a = Eye(3) ⊗ randn(3, 3) + ## n = @constinferred left_null(a) + ## @test norm(n' * a) ≈ 0 + ## @test arguments(n, 1) isa Eye - a = randn(3, 3) ⊗ Eye(3) - n = @constinferred left_null(a) - @test norm(n' * a) ≈ 0 - @test arguments(n, 2) isa Eye + ## a = randn(3, 3) ⊗ Eye(3) + ## n = @constinferred left_null(a) + ## @test norm(n' * a) ≈ 0 + ## @test arguments(n, 2) isa Eye - a = Eye(3) ⊗ Eye(3) - @test_throws MethodError left_null(a) + ## a = Eye(3) ⊗ Eye(3) + ## @test_throws MethodError left_null(a) - # right_null - a = Eye(3) ⊗ randn(3, 3) - n = @constinferred right_null(a) - @test norm(a * n') ≈ 0 - @test arguments(n, 1) isa Eye + ## # right_null + ## a = Eye(3) ⊗ randn(3, 3) + ## n = @constinferred right_null(a) + ## @test norm(a * n') ≈ 0 + ## @test arguments(n, 1) isa Eye - a = randn(3, 3) ⊗ Eye(3) - n = @constinferred right_null(a) - @test norm(a * n') ≈ 0 - @test arguments(n, 2) isa Eye + ## a = randn(3, 3) ⊗ Eye(3) + ## n = @constinferred right_null(a) + ## @test norm(a * n') ≈ 0 + ## @test arguments(n, 2) isa Eye - a = Eye(3) ⊗ Eye(3) - @test_throws MethodError right_null(a) + ## a = Eye(3) ⊗ Eye(3) + ## @test_throws MethodError right_null(a) end From cb5282832b8ef403802816d17b2423439e9c1d16 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 15 Jun 2025 18:09:57 -0400 Subject: [PATCH 2/5] Truncate --- src/KroneckerArrays.jl | 1 + src/fillarrays/matrixalgebrakit.jl | 3 + src/fillarrays/matrixalgebrakit_truncate.jl | 81 ++++++++++++++++++++ src/matrixalgebrakit.jl | 4 +- test/test_fillarrays_matrixalgebrakit.jl | 84 ++++++++++----------- 5 files changed, 130 insertions(+), 43 deletions(-) create mode 100644 src/fillarrays/matrixalgebrakit_truncate.jl diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 597872c..4552a2f 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -9,5 +9,6 @@ include("matrixalgebrakit.jl") include("fillarrays/kroneckerarray.jl") include("fillarrays/linearalgebra.jl") include("fillarrays/matrixalgebrakit.jl") +include("fillarrays/matrixalgebrakit_truncate.jl") end diff --git a/src/fillarrays/matrixalgebrakit.jl b/src/fillarrays/matrixalgebrakit.jl index 7a1f138..d8ed07b 100644 --- a/src/fillarrays/matrixalgebrakit.jl +++ b/src/fillarrays/matrixalgebrakit.jl @@ -15,6 +15,9 @@ function supremum(r1::AbstractRange, r2::AbstractUnitRange) end end +# Allow customization for `Eye`. +_diagview(a::Eye) = parent(a) + function _copy_input(f::F, a::Eye) where {F} return a end diff --git a/src/fillarrays/matrixalgebrakit_truncate.jl b/src/fillarrays/matrixalgebrakit_truncate.jl new file mode 100644 index 0000000..ae50f27 --- /dev/null +++ b/src/fillarrays/matrixalgebrakit_truncate.jl @@ -0,0 +1,81 @@ +using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate! + +struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy + strategy::T +end + +# Avoid instantiating the identity. +function Base.getindex(a::EyeKronecker, I::Vararg{CartesianProduct{Colon},2}) + return a.a ⊗ a.b[I[1].b, I[2].b] +end +function Base.getindex(a::KroneckerEye, I::Vararg{CartesianProduct{<:Any,Colon},2}) + return a.a[I[1].a, I[2].a] ⊗ a.b +end +function Base.getindex(a::EyeEye, I::Vararg{CartesianProduct{Colon,Colon},2}) + return a +end + +using FillArrays: OnesVector +const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} +const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} +const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} + +function MatrixAlgebraKit.findtruncated( + values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy +) + I = findtruncated(Vector(values), strategy.strategy) + prods = collect(only(axes(values)).product)[I] + I_data = unique(map(x -> x.a, prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> x.a == i, prods) == length(values.a) + end + return (:) × I_data +end +function MatrixAlgebraKit.findtruncated( + values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy +) + I = findtruncated(Vector(values), strategy.strategy) + prods = collect(only(axes(values)).product)[I] + I_data = unique(map(x -> x.b, prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> x.b == i, prods) == length(values.b) + end + return I_data × (:) +end +function MatrixAlgebraKit.findtruncated( + values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy +) + return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) +end + +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy + ) + return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) + end + function MatrixAlgebraKit.truncate!( + ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy + ) + I = findtruncated(diagview(D), strategy) + return (D[I, I], V[(:) × (:), I]) + end + end +end + +function MatrixAlgebraKit.truncate!( + f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy +) + return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) +end +function MatrixAlgebraKit.truncate!( + ::typeof(svd_trunc!), + (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, + strategy::KroneckerTruncationStrategy, +) + I = findtruncated(diagview(S), strategy) + return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) +end diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index 1e3935d..6e8f479 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -32,8 +32,10 @@ using MatrixAlgebraKit: truncate! using MatrixAlgebraKit: MatrixAlgebraKit, diagview +# Allow customization for `Eye`. +_diagview(a::AbstractMatrix) = diagview(a) function MatrixAlgebraKit.diagview(a::KroneckerMatrix) - return diagview(a.a) ⊗ diagview(a.b) + return _diagview(a.a) ⊗ _diagview(a.b) end struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm diff --git a/test/test_fillarrays_matrixalgebrakit.jl b/test/test_fillarrays_matrixalgebrakit.jl index 355c36d..778b701 100644 --- a/test/test_fillarrays_matrixalgebrakit.jl +++ b/test/test_fillarrays_matrixalgebrakit.jl @@ -176,51 +176,51 @@ herm(a) = parent(hermitianpart(a)) end end - ## # svd_trunc - ## for elt in (Float32, ComplexF32) - ## a = Eye{elt}(3) ⊗ randn(elt, 3, 3) - ## # TODO: Type inference is broken for `svd_trunc`, - ## # look into fixing it. - ## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - ## u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - ## @test eltype(u) === elt - ## @test eltype(s) === real(elt) - ## @test eltype(v) === elt - ## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - ## @test Matrix(u * s * v) ≈ u′ * s′ * v′ - ## @test arguments(u, 1) isa Eye{elt} - ## @test arguments(s, 1) isa Eye{real(elt)} - ## @test arguments(v, 1) isa Eye{elt} - ## @test size(u) == (9, 6) - ## @test size(s) == (6, 6) - ## @test size(v) == (6, 9) - ## end + # svd_trunc + for elt in (Float32, ComplexF32) + a = Eye{elt}(3, 3) ⊗ randn(elt, 3, 3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 1) isa Eye{elt} + @test arguments(s, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end - ## for elt in (Float32, ComplexF32) - ## a = randn(elt, 3, 3) ⊗ Eye{elt}(3) - ## # TODO: Type inference is broken for `svd_trunc`, - ## # look into fixing it. - ## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - ## u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - ## @test eltype(u) === elt - ## @test eltype(s) === real(elt) - ## @test eltype(v) === elt - ## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - ## @test Matrix(u * s * v) ≈ u′ * s′ * v′ - ## @test arguments(u, 2) isa Eye{elt} - ## @test arguments(s, 2) isa Eye{real(elt)} - ## @test arguments(v, 2) isa Eye{elt} - ## @test size(u) == (9, 6) - ## @test size(s) == (6, 6) - ## @test size(v) == (6, 9) - ## end + for elt in (Float32, ComplexF32) + a = randn(elt, 3, 3) ⊗ Eye{elt}(3, 3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 2) isa Eye{elt} + @test arguments(s, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end - ## a = Eye(3) ⊗ Eye(3) - ## @test_throws ArgumentError svd_trunc(a) + a = Eye(3, 3) ⊗ Eye(3, 3) + @test_throws ArgumentError svd_trunc(a) # svd_vals for elt in (Float32, ComplexF32) - a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + a = Eye{elt}(3, 3) ⊗ randn(elt, 3, 3) d = @constinferred svd_vals(a) d′ = svd_vals(Matrix(a)) @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) @@ -246,12 +246,12 @@ herm(a) = parent(hermitianpart(a)) end ## # left_null - ## a = Eye(3) ⊗ randn(3, 3) + ## a = Eye(3, 3) ⊗ randn(3, 3) ## n = @constinferred left_null(a) ## @test norm(n' * a) ≈ 0 ## @test arguments(n, 1) isa Eye - ## a = randn(3, 3) ⊗ Eye(3) + ## a = randn(3, 3) ⊗ Eye(3, 3) ## n = @constinferred left_null(a) ## @test norm(n' * a) ≈ 0 ## @test arguments(n, 2) isa Eye From 394fec2ab441bd4fea006cca4415e33fab20da14 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 15 Jun 2025 18:27:31 -0400 Subject: [PATCH 3/5] Add back support for null space --- src/fillarrays/matrixalgebrakit.jl | 16 +++++++++ src/matrixalgebrakit.jl | 8 +++-- test/test_fillarrays_matrixalgebrakit.jl | 44 ++++++++++++------------ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/fillarrays/matrixalgebrakit.jl b/src/fillarrays/matrixalgebrakit.jl index d8ed07b..093760b 100644 --- a/src/fillarrays/matrixalgebrakit.jl +++ b/src/fillarrays/matrixalgebrakit.jl @@ -148,3 +148,19 @@ for f in [:left_orth!, :right_orth!] end end end + +for f in [:left_null!, :right_null!] + _f = Symbol(:_, f) + @eval begin + function _initialize_output(::typeof($f), a::Eye) + return a + end + function $_f(a::Eye, F) + return F + end + + function MatrixAlgebraKit.$f(a::EyeEye, F; kwargs...) + return throw(MethodError($f, (a, F))) + end + end +end diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index 6e8f479..7e88608 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -180,13 +180,17 @@ for f in [:left_orth!, :right_orth!] end for f in [:left_null!, :right_null!] + _f = Symbol(:_, f) @eval begin function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) return _initialize_output($f, a.a) ⊗ _initialize_output($f, a.b) end + function $_f(a::AbstractMatrix, F; kwargs...) + return $f(a, F; kwargs...) + end function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...) - $f(a.a, F.a; kwargs..., kwargs1...) - $f(a.b, F.b; kwargs..., kwargs2...) + $_f(a.a, F.a; kwargs..., kwargs1...) + $_f(a.b, F.b; kwargs..., kwargs2...) return F end end diff --git a/test/test_fillarrays_matrixalgebrakit.jl b/test/test_fillarrays_matrixalgebrakit.jl index 778b701..b3ed2d5 100644 --- a/test/test_fillarrays_matrixalgebrakit.jl +++ b/test/test_fillarrays_matrixalgebrakit.jl @@ -245,31 +245,31 @@ herm(a) = parent(hermitianpart(a)) @test arguments(d, 2) isa Ones{real(elt)} end - ## # left_null - ## a = Eye(3, 3) ⊗ randn(3, 3) - ## n = @constinferred left_null(a) - ## @test norm(n' * a) ≈ 0 - ## @test arguments(n, 1) isa Eye + # left_null + a = Eye(3, 3) ⊗ randn(3, 3) + n = @constinferred left_null(a) + @test norm(n' * a) ≈ 0 + @test arguments(n, 1) isa Eye - ## a = randn(3, 3) ⊗ Eye(3, 3) - ## n = @constinferred left_null(a) - ## @test norm(n' * a) ≈ 0 - ## @test arguments(n, 2) isa Eye + a = randn(3, 3) ⊗ Eye(3, 3) + n = @constinferred left_null(a) + @test norm(n' * a) ≈ 0 + @test arguments(n, 2) isa Eye - ## a = Eye(3) ⊗ Eye(3) - ## @test_throws MethodError left_null(a) + a = Eye(3) ⊗ Eye(3) + @test_throws MethodError left_null(a) - ## # right_null - ## a = Eye(3) ⊗ randn(3, 3) - ## n = @constinferred right_null(a) - ## @test norm(a * n') ≈ 0 - ## @test arguments(n, 1) isa Eye + # right_null + a = Eye(3) ⊗ randn(3, 3) + n = @constinferred right_null(a) + @test norm(a * n') ≈ 0 + @test arguments(n, 1) isa Eye - ## a = randn(3, 3) ⊗ Eye(3) - ## n = @constinferred right_null(a) - ## @test norm(a * n') ≈ 0 - ## @test arguments(n, 2) isa Eye + a = randn(3, 3) ⊗ Eye(3) + n = @constinferred right_null(a) + @test norm(a * n') ≈ 0 + @test arguments(n, 2) isa Eye - ## a = Eye(3) ⊗ Eye(3) - ## @test_throws MethodError right_null(a) + a = Eye(3) ⊗ Eye(3) + @test_throws MethodError right_null(a) end From 25e9074cd3462792060ae86633bfcf92a07bb7b5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 15 Jun 2025 18:29:26 -0400 Subject: [PATCH 4/5] Bring back eig_trunc tests --- test/test_fillarrays_matrixalgebrakit.jl | 36 ++++++++++++------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/test/test_fillarrays_matrixalgebrakit.jl b/test/test_fillarrays_matrixalgebrakit.jl index b3ed2d5..d785bd6 100644 --- a/test/test_fillarrays_matrixalgebrakit.jl +++ b/test/test_fillarrays_matrixalgebrakit.jl @@ -72,26 +72,26 @@ herm(a) = parent(hermitianpart(a)) @test arguments(v, 2) isa Eye{elt} end - ## for f in (eig_trunc, eigh_trunc) - ## a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) - ## d, v = f(a; trunc=(; maxrank=7)) - ## @test a * v ≈ v * d - ## @test arguments(d, 1) isa Eye - ## @test arguments(v, 1) isa Eye - ## @test size(d) == (6, 6) - ## @test size(v) == (9, 6) + for f in (eig_trunc, eigh_trunc) + a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) + d, v = f(a; trunc=(; maxrank=7)) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye + @test arguments(v, 1) isa Eye + @test size(d) == (6, 6) + @test size(v) == (9, 6) - ## a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) - ## d, v = f(a; trunc=(; maxrank=7)) - ## @test a * v ≈ v * d - ## @test arguments(d, 2) isa Eye - ## @test arguments(v, 2) isa Eye - ## @test size(d) == (6, 6) - ## @test size(v) == (9, 6) + a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) + d, v = f(a; trunc=(; maxrank=7)) + @test a * v ≈ v * d + @test arguments(d, 2) isa Eye + @test arguments(v, 2) isa Eye + @test size(d) == (6, 6) + @test size(v) == (9, 6) - ## a = Eye(3) ⊗ Eye(3) - ## @test_throws ArgumentError f(a) - ## end + a = Eye(3) ⊗ Eye(3) + @test_throws ArgumentError f(a) + end for f in (eig_vals, eigh_vals) a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) From abe442ffd63695beebc9ac1f548ba555e133a2f7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 15 Jun 2025 18:31:32 -0400 Subject: [PATCH 5/5] Bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 196b1f5..6562906 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.12" +version = "0.1.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"