Skip to content

Commit 2e35549

Browse files
committed
[WIP] MatrixAlgebraKit factorizations on Kronecker with Eye
1 parent e5325c1 commit 2e35549

File tree

2 files changed

+218
-18
lines changed

2 files changed

+218
-18
lines changed

src/KroneckerArrays.jl

Lines changed: 149 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -608,12 +608,22 @@ for f in [
608608
end
609609
end
610610

611-
for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
612-
ff = Symbol("default_", f, "_algorithm")
611+
for f in [
612+
:default_eig_algorithm,
613+
:default_eigh_algorithm,
614+
:default_lq_algorithm,
615+
:default_qr_algorithm,
616+
:default_polar_algorithm,
617+
:default_svd_algorithm,
618+
]
613619
@eval begin
614-
function MatrixAlgebraKit.$ff(A::Type{<:KroneckerMatrix}; kwargs...)
620+
function MatrixAlgebraKit.$f(
621+
A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs...
622+
)
615623
A1, A2 = argument_types(A)
616-
return KroneckerAlgorithm($ff(A1; kwargs...), $ff(A2; kwargs...))
624+
return KroneckerAlgorithm(
625+
$f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...)
626+
)
617627
end
618628
end
619629
end
@@ -631,7 +641,7 @@ function MatrixAlgebraKit.default_algorithm(
631641
return default_qr_algorithm(A; kwargs...)
632642
end
633643

634-
for f in (
644+
for f in [
635645
:eig_full!,
636646
:eigh_full!,
637647
:qr_compact!,
@@ -642,22 +652,24 @@ for f in (
642652
:right_polar!,
643653
:svd_compact!,
644654
:svd_full!,
645-
)
655+
]
646656
@eval begin
647657
function MatrixAlgebraKit.initialize_output(
648658
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
649659
)
650660
return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b)
651661
end
652-
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...)
653-
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...)
654-
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...)
662+
function MatrixAlgebraKit.$f(
663+
a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
664+
)
665+
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...)
666+
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...)
655667
return F
656668
end
657669
end
658670
end
659671

660-
for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
672+
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
661673
@eval begin
662674
function MatrixAlgebraKit.initialize_output(
663675
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
@@ -672,7 +684,7 @@ for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
672684
end
673685
end
674686

675-
for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
687+
for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!]
676688
@eval begin
677689
function MatrixAlgebraKit.truncate!(
678690
::typeof($f),
@@ -684,25 +696,146 @@ for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
684696
end
685697
end
686698

687-
for f in (:left_orth!, :right_orth!)
699+
for f in [:left_orth!, :right_orth!]
688700
@eval begin
689701
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
690702
return initialize_output($f, a.a) .⊗ initialize_output($f, a.b)
691703
end
692704
end
693705
end
694706

695-
for f in (:left_null!, :right_null!)
707+
for f in [:left_null!, :right_null!]
696708
@eval begin
697709
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
698710
return initialize_output($f, a.a) initialize_output($f, a.b)
699711
end
700-
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...)
701-
$f(a.a, F.a; kwargs...)
702-
$f(a.b, F.b; kwargs...)
712+
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...)
713+
$f(a.a, F.a; kwargs..., kwargs1...)
714+
$f(a.b, F.b; kwargs..., kwargs2...)
703715
return F
704716
end
705717
end
706718
end
707719

720+
####################################################################################
721+
# Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and
722+
# `A ⊗ Eye(n)` where `A`.
723+
# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34
724+
# is merged.
725+
726+
using FillArrays: SquareEye
727+
const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
728+
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
729+
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
730+
731+
struct SquareEyeAlgorithm <: AbstractAlgorithm end
732+
733+
# Defined to avoid type piracy.
734+
_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a)
735+
_copy_input_squareeye(f::F, a::SquareEye) where {F} = a
736+
737+
for f in [
738+
:eig_full,
739+
:eigh_full,
740+
:qr_compact,
741+
:qr_full,
742+
:left_polar,
743+
:lq_compact,
744+
:lq_full,
745+
:right_polar,
746+
:svd_compact,
747+
:svd_full,
748+
]
749+
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
750+
@eval begin
751+
function MatrixAlgebraKit.copy_input(::typeof($f), a::$T)
752+
return _copy_input_squareeye($f, a.a) _copy_input_squareeye($f, a.b)
753+
end
754+
end
755+
end
756+
end
757+
758+
for f in [
759+
:default_eig_algorithm,
760+
:default_eigh_algorithm,
761+
:default_lq_algorithm,
762+
:default_qr_algorithm,
763+
:default_polar_algorithm,
764+
:default_svd_algorithm,
765+
]
766+
f′ = Symbol("_", f, "_squareeye")
767+
@eval begin
768+
$f′(a) = $f(a)
769+
$f′(a::Type{<:SquareEye}) = SquareEyeAlgorithm()
770+
end
771+
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
772+
@eval begin
773+
function MatrixAlgebraKit.$f(A::Type{<:$T}; kwargs1=(;), kwargs2=(;), kwargs...)
774+
A1, A2 = argument_types(A)
775+
return KroneckerAlgorithm(
776+
$f′(A1; kwargs..., kwargs1...), $f′(A2; kwargs..., kwargs2...)
777+
)
778+
end
779+
end
780+
end
781+
end
782+
783+
# Defined to avoid type piracy.
784+
_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg)
785+
for f in [
786+
:eig_full!,
787+
:eigh_full!,
788+
:qr_compact!,
789+
:qr_full!,
790+
:left_polar!,
791+
:lq_compact!,
792+
:lq_full!,
793+
:right_polar!,
794+
]
795+
@eval begin
796+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a)
797+
end
798+
end
799+
for f in [:svd_compact!, :svd_full!]
800+
@eval begin
801+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a)
802+
end
803+
end
804+
805+
for f in [
806+
:eig_full!,
807+
:eigh_full!,
808+
:qr_compact!,
809+
:qr_full!,
810+
:left_polar!,
811+
:lq_compact!,
812+
:lq_full!,
813+
:right_polar!,
814+
:svd_compact!,
815+
:svd_full!,
816+
]
817+
f′ = Symbol("_", f, "_squareeye")
818+
@eval begin
819+
$f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...)
820+
$f′(a, F, alg::SquareEyeAlgorithm) = F
821+
end
822+
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
823+
@eval begin
824+
function MatrixAlgebraKit.initialize_output(
825+
::typeof($f), a::$T, alg::KroneckerAlgorithm
826+
)
827+
return _initialize_output_squareeye($f, a.a, alg.a) .⊗
828+
_initialize_output_squareeye($f, a.b, alg.b)
829+
end
830+
function MatrixAlgebraKit.$f(
831+
a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
832+
)
833+
$f′(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...)
834+
$f′(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...)
835+
return F
836+
end
837+
end
838+
end
839+
end
840+
708841
end

test/test_matrixalgebrakit.jl

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using KroneckerArrays:
1+
using FillArrays: Eye
2+
using KroneckerArrays: , arguments
23
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
34
using MatrixAlgebraKit:
45
eig_full,
@@ -22,8 +23,9 @@ using MatrixAlgebraKit:
2223
svd_trunc,
2324
svd_vals
2425
using Test: @test, @test_throws, @testset
26+
using TestExtras: @constinferred
2527

26-
herm(a) = hermitianpart(a).data
28+
herm(a) = parent(hermitianpart(a))
2729

2830
@testset "MatrixAlgebraKit" begin
2931
elt = Float32
@@ -117,3 +119,68 @@ herm(a) = hermitianpart(a).data
117119
s = svd_vals(a)
118120
@test s diag(svd_compact(a)[2])
119121
end
122+
123+
@testset "MatrixAlgebraKit + Eye" begin
124+
125+
# eig_trunc
126+
# eig_vals
127+
# eigh_trunc
128+
# eigh_vals
129+
# left_null
130+
# right_null
131+
# svd_compact
132+
# svd_full
133+
# svd_trunc
134+
# svd_vals
135+
136+
for f in (eig_full, eigh_full)
137+
a = Eye(3) parent(hermitianpart(randn(3, 3)))
138+
d, v = @constinferred f(a)
139+
@test a * v v * d
140+
@test arguments(d, 1) isa Eye
141+
@test arguments(v, 1) isa Eye
142+
143+
a = parent(hermitianpart(randn(3, 3))) Eye(3)
144+
d, v = @constinferred f(a)
145+
@test a * v v * d
146+
@test arguments(d, 2) isa Eye
147+
@test arguments(v, 2) isa Eye
148+
149+
a = Eye(3) Eye(3)
150+
d, v = @constinferred f(a)
151+
@test a * v v * d
152+
@test arguments(d, 1) isa Eye
153+
@test arguments(d, 2) isa Eye
154+
@test arguments(v, 1) isa Eye
155+
@test arguments(v, 2) isa Eye
156+
end
157+
158+
for f in (
159+
#=left_orth,=#left_polar,
160+
lq_compact,
161+
lq_full,
162+
qr_compact,
163+
qr_full,
164+
#=right_orth,=#right_polar,
165+
)
166+
a = Eye(3) randn(3, 3)
167+
x, y = f(a)
168+
@test x * y a
169+
@test arguments(x, 1) isa Eye
170+
@test arguments(y, 1) isa Eye
171+
172+
a = randn(3, 3) Eye(3)
173+
x, y = f(a)
174+
@test x * y a
175+
@test arguments(x, 2) isa Eye
176+
@test arguments(y, 2) isa Eye
177+
178+
a = Eye(3) Eye(3)
179+
x, y = f(a)
180+
@test x * y a
181+
@test arguments(x, 1) isa Eye
182+
@test arguments(y, 1) isa Eye
183+
@test arguments(x, 2) isa Eye
184+
@test arguments(y, 2) isa Eye
185+
end
186+
end

0 commit comments

Comments
 (0)