Skip to content

Commit 3a93e83

Browse files
authored
MatrixAlgebraKit factorizations on Kronecker with Eye (#7)
1 parent e5325c1 commit 3a93e83

File tree

4 files changed

+260
-19
lines changed

4 files changed

+260
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.4"
4+
version = "0.1.5"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

src/KroneckerArrays.jl

Lines changed: 168 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -582,9 +582,11 @@ using MatrixAlgebraKit:
582582
eigh_full,
583583
qr_compact,
584584
qr_full,
585+
left_orth,
585586
left_polar,
586587
lq_compact,
587588
lq_full,
589+
right_orth,
588590
right_polar,
589591
svd_compact,
590592
svd_full
@@ -608,12 +610,22 @@ for f in [
608610
end
609611
end
610612

611-
for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
612-
ff = Symbol("default_", f, "_algorithm")
613+
for f in [
614+
:default_eig_algorithm,
615+
:default_eigh_algorithm,
616+
:default_lq_algorithm,
617+
:default_qr_algorithm,
618+
:default_polar_algorithm,
619+
:default_svd_algorithm,
620+
]
613621
@eval begin
614-
function MatrixAlgebraKit.$ff(A::Type{<:KroneckerMatrix}; kwargs...)
622+
function MatrixAlgebraKit.$f(
623+
A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs...
624+
)
615625
A1, A2 = argument_types(A)
616-
return KroneckerAlgorithm($ff(A1; kwargs...), $ff(A2; kwargs...))
626+
return KroneckerAlgorithm(
627+
$f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...)
628+
)
617629
end
618630
end
619631
end
@@ -631,7 +643,7 @@ function MatrixAlgebraKit.default_algorithm(
631643
return default_qr_algorithm(A; kwargs...)
632644
end
633645

634-
for f in (
646+
for f in [
635647
:eig_full!,
636648
:eigh_full!,
637649
:qr_compact!,
@@ -642,22 +654,24 @@ for f in (
642654
:right_polar!,
643655
:svd_compact!,
644656
:svd_full!,
645-
)
657+
]
646658
@eval begin
647659
function MatrixAlgebraKit.initialize_output(
648660
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
649661
)
650662
return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b)
651663
end
652-
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...)
653-
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...)
654-
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...)
664+
function MatrixAlgebraKit.$f(
665+
a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
666+
)
667+
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...)
668+
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...)
655669
return F
656670
end
657671
end
658672
end
659673

660-
for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
674+
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
661675
@eval begin
662676
function MatrixAlgebraKit.initialize_output(
663677
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
@@ -672,7 +686,7 @@ for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
672686
end
673687
end
674688

675-
for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
689+
for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!]
676690
@eval begin
677691
function MatrixAlgebraKit.truncate!(
678692
::typeof($f),
@@ -684,25 +698,163 @@ for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
684698
end
685699
end
686700

687-
for f in (:left_orth!, :right_orth!)
701+
for f in [:left_orth!, :right_orth!]
688702
@eval begin
689703
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
690704
return initialize_output($f, a.a) .⊗ initialize_output($f, a.b)
691705
end
692706
end
693707
end
694708

695-
for f in (:left_null!, :right_null!)
709+
for f in [:left_null!, :right_null!]
696710
@eval begin
697711
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
698712
return initialize_output($f, a.a) initialize_output($f, a.b)
699713
end
700-
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...)
701-
$f(a.a, F.a; kwargs...)
702-
$f(a.b, F.b; kwargs...)
714+
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...)
715+
$f(a.a, F.a; kwargs..., kwargs1...)
716+
$f(a.b, F.b; kwargs..., kwargs2...)
703717
return F
704718
end
705719
end
706720
end
707721

722+
####################################################################################
723+
# Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and
724+
# `A ⊗ Eye(n)` where `A`.
725+
# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34
726+
# is merged.
727+
728+
using FillArrays: SquareEye
729+
const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
730+
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
731+
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
732+
733+
struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm
734+
kwargs::KWargs
735+
end
736+
SquareEyeAlgorithm(; kwargs...) = SquareEyeAlgorithm((; kwargs...))
737+
738+
# Defined to avoid type piracy.
739+
_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a)
740+
_copy_input_squareeye(f::F, a::SquareEye) where {F} = a
741+
742+
for f in [
743+
:eig_full,
744+
:eigh_full,
745+
:qr_compact,
746+
:qr_full,
747+
:left_orth,
748+
:left_polar,
749+
:lq_compact,
750+
:lq_full,
751+
:right_orth,
752+
:right_polar,
753+
:svd_compact,
754+
:svd_full,
755+
]
756+
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
757+
@eval begin
758+
function MatrixAlgebraKit.copy_input(::typeof($f), a::$T)
759+
return _copy_input_squareeye($f, a.a) _copy_input_squareeye($f, a.b)
760+
end
761+
end
762+
end
763+
end
764+
765+
for f in [
766+
:default_eig_algorithm,
767+
:default_eigh_algorithm,
768+
:default_lq_algorithm,
769+
:default_qr_algorithm,
770+
:default_polar_algorithm,
771+
:default_svd_algorithm,
772+
]
773+
f′ = Symbol("_", f, "_squareeye")
774+
@eval begin
775+
$f′(a; kwargs...) = $f(a; kwargs...)
776+
$f′(a::Type{<:SquareEye}; kwargs...) = SquareEyeAlgorithm(; kwargs...)
777+
end
778+
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
779+
@eval begin
780+
function MatrixAlgebraKit.$f(A::Type{<:$T}; kwargs1=(;), kwargs2=(;), kwargs...)
781+
A1, A2 = argument_types(A)
782+
return KroneckerAlgorithm(
783+
$f′(A1; kwargs..., kwargs1...), $f′(A2; kwargs..., kwargs2...)
784+
)
785+
end
786+
end
787+
end
788+
end
789+
790+
# Defined to avoid type piracy.
791+
_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a)
792+
_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg)
793+
794+
for f in [
795+
:eig_full!,
796+
:eigh_full!,
797+
:qr_compact!,
798+
:qr_full!,
799+
:left_orth!,
800+
:left_polar!,
801+
:lq_compact!,
802+
:lq_full!,
803+
:right_orth!,
804+
:right_polar!,
805+
]
806+
@eval begin
807+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a)
808+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a)
809+
end
810+
end
811+
for f in [:svd_compact!, :svd_full!]
812+
@eval begin
813+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a, a)
814+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a)
815+
end
816+
end
817+
818+
for f in [
819+
:eig_full!,
820+
:eigh_full!,
821+
:qr_compact!,
822+
:qr_full!,
823+
:left_orth!,
824+
:left_polar!,
825+
:lq_compact!,
826+
:lq_full!,
827+
:right_orth!,
828+
:right_polar!,
829+
:svd_compact!,
830+
:svd_full!,
831+
]
832+
f′ = Symbol("_", f, "_squareeye")
833+
@eval begin
834+
$f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...)
835+
$f′(a, F, alg::SquareEyeAlgorithm) = F
836+
end
837+
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
838+
@eval begin
839+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T)
840+
return _initialize_output_squareeye($f, a.a) .⊗
841+
_initialize_output_squareeye($f, a.b)
842+
end
843+
function MatrixAlgebraKit.initialize_output(
844+
::typeof($f), a::$T, alg::KroneckerAlgorithm
845+
)
846+
return _initialize_output_squareeye($f, a.a, alg.a) .⊗
847+
_initialize_output_squareeye($f, a.b, alg.b)
848+
end
849+
function MatrixAlgebraKit.$f(
850+
a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
851+
)
852+
$f′(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...)
853+
$f′(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...)
854+
return F
855+
end
856+
end
857+
end
858+
end
859+
708860
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
77
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
88
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
99
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10+
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1011

1112
[compat]
1213
Aqua = "0.8"
@@ -17,3 +18,4 @@ MatrixAlgebraKit = "0.2"
1718
SafeTestsets = "0.1"
1819
Suppressor = "0.2"
1920
Test = "1.10"
21+
TestExtras = "0.3"

test/test_matrixalgebrakit.jl

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using KroneckerArrays:
1+
using FillArrays: Eye
2+
using KroneckerArrays: , arguments
23
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
34
using MatrixAlgebraKit:
45
eig_full,
@@ -22,8 +23,9 @@ using MatrixAlgebraKit:
2223
svd_trunc,
2324
svd_vals
2425
using Test: @test, @test_throws, @testset
26+
using TestExtras: @constinferred
2527

26-
herm(a) = hermitianpart(a).data
28+
herm(a) = parent(hermitianpart(a))
2729

2830
@testset "MatrixAlgebraKit" begin
2931
elt = Float32
@@ -117,3 +119,88 @@ herm(a) = hermitianpart(a).data
117119
s = svd_vals(a)
118120
@test s diag(svd_compact(a)[2])
119121
end
122+
123+
@testset "MatrixAlgebraKit + Eye" begin
124+
125+
# TODO:
126+
# eig_trunc
127+
# eig_vals
128+
# eigh_trunc
129+
# eigh_vals
130+
# left_null
131+
# right_null
132+
# svd_trunc
133+
# svd_vals
134+
135+
for f in (eig_full, eigh_full)
136+
a = Eye(3) parent(hermitianpart(randn(3, 3)))
137+
d, v = @constinferred f(a)
138+
@test a * v v * d
139+
@test arguments(d, 1) isa Eye
140+
@test arguments(v, 1) isa Eye
141+
142+
a = parent(hermitianpart(randn(3, 3))) Eye(3)
143+
d, v = @constinferred f(a)
144+
@test a * v v * d
145+
@test arguments(d, 2) isa Eye
146+
@test arguments(v, 2) isa Eye
147+
148+
a = Eye(3) Eye(3)
149+
d, v = @constinferred f(a)
150+
@test a * v v * d
151+
@test arguments(d, 1) isa Eye
152+
@test arguments(d, 2) isa Eye
153+
@test arguments(v, 1) isa Eye
154+
@test arguments(v, 2) isa Eye
155+
end
156+
157+
for f in (
158+
left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar
159+
)
160+
a = Eye(3) randn(3, 3)
161+
x, y = f(a)
162+
@test x * y a
163+
@test arguments(x, 1) isa Eye
164+
@test arguments(y, 1) isa Eye
165+
166+
a = randn(3, 3) Eye(3)
167+
x, y = f(a)
168+
@test x * y a
169+
@test arguments(x, 2) isa Eye
170+
@test arguments(y, 2) isa Eye
171+
172+
a = Eye(3) Eye(3)
173+
x, y = f(a)
174+
@test x * y a
175+
@test arguments(x, 1) isa Eye
176+
@test arguments(y, 1) isa Eye
177+
@test arguments(x, 2) isa Eye
178+
@test arguments(y, 2) isa Eye
179+
end
180+
181+
for f in (svd_compact, svd_full)
182+
a = Eye(3) randn(3, 3)
183+
u, s, v = f(a)
184+
@test u * s * v a
185+
@test arguments(u, 1) isa Eye
186+
@test arguments(s, 1) isa Eye
187+
@test arguments(v, 1) isa Eye
188+
189+
a = randn(3, 3) Eye(3)
190+
u, s, v = f(a)
191+
@test u * s * v a
192+
@test arguments(u, 2) isa Eye
193+
@test arguments(s, 2) isa Eye
194+
@test arguments(v, 2) isa Eye
195+
196+
a = Eye(3) Eye(3)
197+
u, s, v = f(a)
198+
@test u * s * v a
199+
@test arguments(u, 1) isa Eye
200+
@test arguments(s, 1) isa Eye
201+
@test arguments(v, 1) isa Eye
202+
@test arguments(u, 2) isa Eye
203+
@test arguments(s, 2) isa Eye
204+
@test arguments(v, 2) isa Eye
205+
end
206+
end

0 commit comments

Comments
 (0)