diff --git a/Project.toml b/Project.toml index 8ed52d9..29e0e8a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.4" +version = "0.1.5" [deps] DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index bdedfc7..06fd869 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -582,9 +582,11 @@ using MatrixAlgebraKit: eigh_full, qr_compact, qr_full, + left_orth, left_polar, lq_compact, lq_full, + right_orth, right_polar, svd_compact, svd_full @@ -608,12 +610,22 @@ for f in [ end end -for f in (:eig, :eigh, :lq, :qr, :polar, :svd) - ff = Symbol("default_", f, "_algorithm") +for f in [ + :default_eig_algorithm, + :default_eigh_algorithm, + :default_lq_algorithm, + :default_qr_algorithm, + :default_polar_algorithm, + :default_svd_algorithm, +] @eval begin - function MatrixAlgebraKit.$ff(A::Type{<:KroneckerMatrix}; kwargs...) + function MatrixAlgebraKit.$f( + A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs... + ) A1, A2 = argument_types(A) - return KroneckerAlgorithm($ff(A1; kwargs...), $ff(A2; kwargs...)) + return KroneckerAlgorithm( + $f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...) + ) end end end @@ -631,7 +643,7 @@ function MatrixAlgebraKit.default_algorithm( return default_qr_algorithm(A; kwargs...) end -for f in ( +for f in [ :eig_full!, :eigh_full!, :qr_compact!, @@ -642,22 +654,24 @@ for f in ( :right_polar!, :svd_compact!, :svd_full!, -) +] @eval begin function MatrixAlgebraKit.initialize_output( ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm ) return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b) end - function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...) - $f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...) - $f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...) + function MatrixAlgebraKit.$f( + a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... + ) + $f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...) + $f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...) return F end end end -for f in (:eig_vals!, :eigh_vals!, :svd_vals!) +for f in [:eig_vals!, :eigh_vals!, :svd_vals!] @eval begin function MatrixAlgebraKit.initialize_output( ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm @@ -672,7 +686,7 @@ for f in (:eig_vals!, :eigh_vals!, :svd_vals!) end end -for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!) +for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!] @eval begin function MatrixAlgebraKit.truncate!( ::typeof($f), @@ -684,7 +698,7 @@ for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!) end end -for f in (:left_orth!, :right_orth!) +for f in [:left_orth!, :right_orth!] @eval begin function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) return initialize_output($f, a.a) .⊗ initialize_output($f, a.b) @@ -692,17 +706,155 @@ for f in (:left_orth!, :right_orth!) end end -for f in (:left_null!, :right_null!) +for f in [:left_null!, :right_null!] @eval begin function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) return initialize_output($f, a.a) ⊗ initialize_output($f, a.b) end - function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...) - $f(a.a, F.a; kwargs...) - $f(a.b, F.b; kwargs...) + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...) + $f(a.a, F.a; kwargs..., kwargs1...) + $f(a.b, F.b; kwargs..., kwargs2...) return F end end end +#################################################################################### +# Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and +# `A ⊗ Eye(n)` where `A`. +# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34 +# is merged. + +using FillArrays: SquareEye +const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B} +const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} +const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} + +struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm + kwargs::KWargs +end +SquareEyeAlgorithm(; kwargs...) = SquareEyeAlgorithm((; kwargs...)) + +# Defined to avoid type piracy. +_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a) +_copy_input_squareeye(f::F, a::SquareEye) where {F} = a + +for f in [ + :eig_full, + :eigh_full, + :qr_compact, + :qr_full, + :left_orth, + :left_polar, + :lq_compact, + :lq_full, + :right_orth, + :right_polar, + :svd_compact, + :svd_full, +] + for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] + @eval begin + function MatrixAlgebraKit.copy_input(::typeof($f), a::$T) + return _copy_input_squareeye($f, a.a) ⊗ _copy_input_squareeye($f, a.b) + end + end + end +end + +for f in [ + :default_eig_algorithm, + :default_eigh_algorithm, + :default_lq_algorithm, + :default_qr_algorithm, + :default_polar_algorithm, + :default_svd_algorithm, +] + f′ = Symbol("_", f, "_squareeye") + @eval begin + $f′(a; kwargs...) = $f(a; kwargs...) + $f′(a::Type{<:SquareEye}; kwargs...) = SquareEyeAlgorithm(; kwargs...) + end + for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] + @eval begin + function MatrixAlgebraKit.$f(A::Type{<:$T}; kwargs1=(;), kwargs2=(;), kwargs...) + A1, A2 = argument_types(A) + return KroneckerAlgorithm( + $f′(A1; kwargs..., kwargs1...), $f′(A2; kwargs..., kwargs2...) + ) + end + end + end +end + +# Defined to avoid type piracy. +_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a) +_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg) + +for f in [ + :eig_full!, + :eigh_full!, + :qr_compact!, + :qr_full!, + :left_orth!, + :left_polar!, + :lq_compact!, + :lq_full!, + :right_orth!, + :right_polar!, +] + @eval begin + _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a) + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a) + end +end +for f in [:svd_compact!, :svd_full!] + @eval begin + _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a, a) + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a) + end +end + +for f in [ + :eig_full!, + :eigh_full!, + :qr_compact!, + :qr_full!, + :left_orth!, + :left_polar!, + :lq_compact!, + :lq_full!, + :right_orth!, + :right_polar!, + :svd_compact!, + :svd_full!, +] + f′ = Symbol("_", f, "_squareeye") + @eval begin + $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) + $f′(a, F, alg::SquareEyeAlgorithm) = F + end + for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) + return _initialize_output_squareeye($f, a.a) .⊗ + _initialize_output_squareeye($f, a.b) + end + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::$T, alg::KroneckerAlgorithm + ) + return _initialize_output_squareeye($f, a.a, alg.a) .⊗ + _initialize_output_squareeye($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f( + a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... + ) + $f′(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...) + $f′(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...) + return F + end + end + end +end + end diff --git a/test/Project.toml b/test/Project.toml index 5d58469..aebc8e9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] Aqua = "0.8" @@ -17,3 +18,4 @@ MatrixAlgebraKit = "0.2" SafeTestsets = "0.1" Suppressor = "0.2" Test = "1.10" +TestExtras = "0.3" diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index 85822fd..a26508d 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,4 +1,5 @@ -using KroneckerArrays: ⊗ +using FillArrays: Eye +using KroneckerArrays: ⊗, arguments using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm using MatrixAlgebraKit: eig_full, @@ -22,8 +23,9 @@ using MatrixAlgebraKit: svd_trunc, svd_vals using Test: @test, @test_throws, @testset +using TestExtras: @constinferred -herm(a) = hermitianpart(a).data +herm(a) = parent(hermitianpart(a)) @testset "MatrixAlgebraKit" begin elt = Float32 @@ -117,3 +119,88 @@ herm(a) = hermitianpart(a).data s = svd_vals(a) @test s ≈ diag(svd_compact(a)[2]) end + +@testset "MatrixAlgebraKit + Eye" begin + + # TODO: + # eig_trunc + # eig_vals + # eigh_trunc + # eigh_vals + # left_null + # right_null + # svd_trunc + # svd_vals + + for f in (eig_full, eigh_full) + a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) + d, v = @constinferred f(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye + @test arguments(v, 1) isa Eye + + a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) + d, v = @constinferred f(a) + @test a * v ≈ v * d + @test arguments(d, 2) isa Eye + @test arguments(v, 2) isa Eye + + a = Eye(3) ⊗ Eye(3) + d, v = @constinferred f(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye + @test arguments(d, 2) isa Eye + @test arguments(v, 1) isa Eye + @test arguments(v, 2) isa Eye + end + + for f in ( + left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar + ) + a = Eye(3) ⊗ randn(3, 3) + x, y = f(a) + @test x * y ≈ a + @test arguments(x, 1) isa Eye + @test arguments(y, 1) isa Eye + + a = randn(3, 3) ⊗ Eye(3) + x, y = f(a) + @test x * y ≈ a + @test arguments(x, 2) isa Eye + @test arguments(y, 2) isa Eye + + a = Eye(3) ⊗ Eye(3) + x, y = f(a) + @test x * y ≈ a + @test arguments(x, 1) isa Eye + @test arguments(y, 1) isa Eye + @test arguments(x, 2) isa Eye + @test arguments(y, 2) isa Eye + end + + for f in (svd_compact, svd_full) + a = Eye(3) ⊗ randn(3, 3) + u, s, v = f(a) + @test u * s * v ≈ a + @test arguments(u, 1) isa Eye + @test arguments(s, 1) isa Eye + @test arguments(v, 1) isa Eye + + a = randn(3, 3) ⊗ Eye(3) + u, s, v = f(a) + @test u * s * v ≈ a + @test arguments(u, 2) isa Eye + @test arguments(s, 2) isa Eye + @test arguments(v, 2) isa Eye + + a = Eye(3) ⊗ Eye(3) + u, s, v = f(a) + @test u * s * v ≈ a + @test arguments(u, 1) isa Eye + @test arguments(s, 1) isa Eye + @test arguments(v, 1) isa Eye + @test arguments(u, 2) isa Eye + @test arguments(s, 2) isa Eye + @test arguments(v, 2) isa Eye + end +end