@@ -582,9 +582,11 @@ using MatrixAlgebraKit:
582582 eigh_full,
583583 qr_compact,
584584 qr_full,
585+ left_orth,
585586 left_polar,
586587 lq_compact,
587588 lq_full,
589+ right_orth,
588590 right_polar,
589591 svd_compact,
590592 svd_full
@@ -608,12 +610,22 @@ for f in [
608610 end
609611end
610612
611- for f in (:eig , :eigh , :lq , :qr , :polar , :svd )
612- ff = Symbol (" default_" , f, " _algorithm" )
613+ for f in [
614+ :default_eig_algorithm ,
615+ :default_eigh_algorithm ,
616+ :default_lq_algorithm ,
617+ :default_qr_algorithm ,
618+ :default_polar_algorithm ,
619+ :default_svd_algorithm ,
620+ ]
613621 @eval begin
614- function MatrixAlgebraKit. $ff (A:: Type{<:KroneckerMatrix} ; kwargs... )
622+ function MatrixAlgebraKit. $f (
623+ A:: Type{<:KroneckerMatrix} ; kwargs1= (;), kwargs2= (;), kwargs...
624+ )
615625 A1, A2 = argument_types (A)
616- return KroneckerAlgorithm ($ ff (A1; kwargs... ), $ ff (A2; kwargs... ))
626+ return KroneckerAlgorithm (
627+ $ f (A1; kwargs... , kwargs1... ), $ f (A2; kwargs... , kwargs2... )
628+ )
617629 end
618630 end
619631end
@@ -631,7 +643,7 @@ function MatrixAlgebraKit.default_algorithm(
631643 return default_qr_algorithm (A; kwargs... )
632644end
633645
634- for f in (
646+ for f in [
635647 :eig_full! ,
636648 :eigh_full! ,
637649 :qr_compact! ,
@@ -642,22 +654,24 @@ for f in (
642654 :right_polar! ,
643655 :svd_compact! ,
644656 :svd_full! ,
645- )
657+ ]
646658 @eval begin
647659 function MatrixAlgebraKit. initialize_output (
648660 :: typeof ($ f), a:: KroneckerMatrix , alg:: KroneckerAlgorithm
649661 )
650662 return initialize_output ($ f, a. a, alg. a) .⊗ initialize_output ($ f, a. b, alg. b)
651663 end
652- function MatrixAlgebraKit. $f (a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm ; kwargs... )
653- $ f (a. a, Base. Fix2 (getfield, :a ).(F), alg. a; kwargs... )
654- $ f (a. b, Base. Fix2 (getfield, :b ).(F), alg. b; kwargs... )
664+ function MatrixAlgebraKit. $f (
665+ a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm ; kwargs1= (;), kwargs2= (;), kwargs...
666+ )
667+ $ f (a. a, Base. Fix2 (getfield, :a ).(F), alg. a; kwargs... , kwargs1... )
668+ $ f (a. b, Base. Fix2 (getfield, :b ).(F), alg. b; kwargs... , kwargs2... )
655669 return F
656670 end
657671 end
658672end
659673
660- for f in ( :eig_vals! , :eigh_vals! , :svd_vals! )
674+ for f in [ :eig_vals! , :eigh_vals! , :svd_vals! ]
661675 @eval begin
662676 function MatrixAlgebraKit. initialize_output (
663677 :: typeof ($ f), a:: KroneckerMatrix , alg:: KroneckerAlgorithm
@@ -672,7 +686,7 @@ for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
672686 end
673687end
674688
675- for f in ( :eig_trunc! , :eigh_trunc! , :svd_trunc! )
689+ for f in [ :eig_trunc! , :eigh_trunc! , :svd_trunc! ]
676690 @eval begin
677691 function MatrixAlgebraKit. truncate! (
678692 :: typeof ($ f),
@@ -684,25 +698,163 @@ for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
684698 end
685699end
686700
687- for f in ( :left_orth! , :right_orth! )
701+ for f in [ :left_orth! , :right_orth! ]
688702 @eval begin
689703 function MatrixAlgebraKit. initialize_output (:: typeof ($ f), a:: KroneckerMatrix )
690704 return initialize_output ($ f, a. a) .⊗ initialize_output ($ f, a. b)
691705 end
692706 end
693707end
694708
695- for f in ( :left_null! , :right_null! )
709+ for f in [ :left_null! , :right_null! ]
696710 @eval begin
697711 function MatrixAlgebraKit. initialize_output (:: typeof ($ f), a:: KroneckerMatrix )
698712 return initialize_output ($ f, a. a) ⊗ initialize_output ($ f, a. b)
699713 end
700- function MatrixAlgebraKit. $f (a:: KroneckerMatrix , F; kwargs... )
701- $ f (a. a, F. a; kwargs... )
702- $ f (a. b, F. b; kwargs... )
714+ function MatrixAlgebraKit. $f (a:: KroneckerMatrix , F; kwargs1 = (;), kwargs2 = (;), kwargs... )
715+ $ f (a. a, F. a; kwargs... , kwargs1 ... )
716+ $ f (a. b, F. b; kwargs... , kwargs2 ... )
703717 return F
704718 end
705719 end
706720end
707721
722+ # ###################################################################################
723+ # Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and
724+ # `A ⊗ Eye(n)` where `A`.
725+ # TODO : Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34
726+ # is merged.
727+
728+ using FillArrays: SquareEye
729+ const SquareEyeKronecker{T,A<: SquareEye{T} ,B<: AbstractMatrix{T} } = KroneckerMatrix{T,A,B}
730+ const KroneckerSquareEye{T,A<: AbstractMatrix{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
731+ const SquareEyeSquareEye{T,A<: SquareEye{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
732+
733+ struct SquareEyeAlgorithm{KWargs<: NamedTuple } <: AbstractAlgorithm
734+ kwargs:: KWargs
735+ end
736+ SquareEyeAlgorithm (; kwargs... ) = SquareEyeAlgorithm ((; kwargs... ))
737+
738+ # Defined to avoid type piracy.
739+ _copy_input_squareeye (f:: F , a) where {F} = copy_input (f, a)
740+ _copy_input_squareeye (f:: F , a:: SquareEye ) where {F} = a
741+
742+ for f in [
743+ :eig_full ,
744+ :eigh_full ,
745+ :qr_compact ,
746+ :qr_full ,
747+ :left_orth ,
748+ :left_polar ,
749+ :lq_compact ,
750+ :lq_full ,
751+ :right_orth ,
752+ :right_polar ,
753+ :svd_compact ,
754+ :svd_full ,
755+ ]
756+ for T in [:SquareEyeKronecker , :KroneckerSquareEye , :SquareEyeSquareEye ]
757+ @eval begin
758+ function MatrixAlgebraKit. copy_input (:: typeof ($ f), a:: $T )
759+ return _copy_input_squareeye ($ f, a. a) ⊗ _copy_input_squareeye ($ f, a. b)
760+ end
761+ end
762+ end
763+ end
764+
765+ for f in [
766+ :default_eig_algorithm ,
767+ :default_eigh_algorithm ,
768+ :default_lq_algorithm ,
769+ :default_qr_algorithm ,
770+ :default_polar_algorithm ,
771+ :default_svd_algorithm ,
772+ ]
773+ f′ = Symbol (" _" , f, " _squareeye" )
774+ @eval begin
775+ $ f′ (a; kwargs... ) = $ f (a; kwargs... )
776+ $ f′ (a:: Type{<:SquareEye} ; kwargs... ) = SquareEyeAlgorithm (; kwargs... )
777+ end
778+ for T in [:SquareEyeKronecker , :KroneckerSquareEye , :SquareEyeSquareEye ]
779+ @eval begin
780+ function MatrixAlgebraKit. $f (A:: Type{<:$T} ; kwargs1= (;), kwargs2= (;), kwargs... )
781+ A1, A2 = argument_types (A)
782+ return KroneckerAlgorithm (
783+ $ f′ (A1; kwargs... , kwargs1... ), $ f′ (A2; kwargs... , kwargs2... )
784+ )
785+ end
786+ end
787+ end
788+ end
789+
790+ # Defined to avoid type piracy.
791+ _initialize_output_squareeye (f:: F , a) where {F} = initialize_output (f, a)
792+ _initialize_output_squareeye (f:: F , a, alg) where {F} = initialize_output (f, a, alg)
793+
794+ for f in [
795+ :eig_full! ,
796+ :eigh_full! ,
797+ :qr_compact! ,
798+ :qr_full! ,
799+ :left_orth! ,
800+ :left_polar! ,
801+ :lq_compact! ,
802+ :lq_full! ,
803+ :right_orth! ,
804+ :right_polar! ,
805+ ]
806+ @eval begin
807+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = (a, a)
808+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, a)
809+ end
810+ end
811+ for f in [:svd_compact! , :svd_full! ]
812+ @eval begin
813+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = (a, a, a)
814+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, a, a)
815+ end
816+ end
817+
818+ for f in [
819+ :eig_full! ,
820+ :eigh_full! ,
821+ :qr_compact! ,
822+ :qr_full! ,
823+ :left_orth! ,
824+ :left_polar! ,
825+ :lq_compact! ,
826+ :lq_full! ,
827+ :right_orth! ,
828+ :right_polar! ,
829+ :svd_compact! ,
830+ :svd_full! ,
831+ ]
832+ f′ = Symbol (" _" , f, " _squareeye" )
833+ @eval begin
834+ $ f′ (a, F, alg; kwargs... ) = $ f (a, F, alg; kwargs... )
835+ $ f′ (a, F, alg:: SquareEyeAlgorithm ) = F
836+ end
837+ for T in [:SquareEyeKronecker , :KroneckerSquareEye , :SquareEyeSquareEye ]
838+ @eval begin
839+ function MatrixAlgebraKit. initialize_output (:: typeof ($ f), a:: $T )
840+ return _initialize_output_squareeye ($ f, a. a) .⊗
841+ _initialize_output_squareeye ($ f, a. b)
842+ end
843+ function MatrixAlgebraKit. initialize_output (
844+ :: typeof ($ f), a:: $T , alg:: KroneckerAlgorithm
845+ )
846+ return _initialize_output_squareeye ($ f, a. a, alg. a) .⊗
847+ _initialize_output_squareeye ($ f, a. b, alg. b)
848+ end
849+ function MatrixAlgebraKit. $f (
850+ a:: $T , F, alg:: KroneckerAlgorithm ; kwargs1= (;), kwargs2= (;), kwargs...
851+ )
852+ $ f′ (a. a, Base. Fix2 (getfield, :a ).(F), alg. a; kwargs... , kwargs1... )
853+ $ f′ (a. b, Base. Fix2 (getfield, :b ).(F), alg. b; kwargs... , kwargs2... )
854+ return F
855+ end
856+ end
857+ end
858+ end
859+
708860end
0 commit comments