@@ -608,12 +608,22 @@ for f in [
608608 end
609609end
610610
611- for f in (:eig , :eigh , :lq , :qr , :polar , :svd )
612- ff = Symbol (" default_" , f, " _algorithm" )
611+ for f in [
612+ :default_eig_algorithm ,
613+ :default_eigh_algorithm ,
614+ :default_lq_algorithm ,
615+ :default_qr_algorithm ,
616+ :default_polar_algorithm ,
617+ :default_svd_algorithm ,
618+ ]
613619 @eval begin
614- function MatrixAlgebraKit. $ff (A:: Type{<:KroneckerMatrix} ; kwargs... )
620+ function MatrixAlgebraKit. $f (
621+ A:: Type{<:KroneckerMatrix} ; kwargs1= (;), kwargs2= (;), kwargs...
622+ )
615623 A1, A2 = argument_types (A)
616- return KroneckerAlgorithm ($ ff (A1; kwargs... ), $ ff (A2; kwargs... ))
624+ return KroneckerAlgorithm (
625+ $ f (A1; kwargs... , kwargs1... ), $ f (A2; kwargs... , kwargs2... )
626+ )
617627 end
618628 end
619629end
@@ -631,7 +641,7 @@ function MatrixAlgebraKit.default_algorithm(
631641 return default_qr_algorithm (A; kwargs... )
632642end
633643
634- for f in (
644+ for f in [
635645 :eig_full! ,
636646 :eigh_full! ,
637647 :qr_compact! ,
@@ -642,22 +652,24 @@ for f in (
642652 :right_polar! ,
643653 :svd_compact! ,
644654 :svd_full! ,
645- )
655+ ]
646656 @eval begin
647657 function MatrixAlgebraKit. initialize_output (
648658 :: typeof ($ f), a:: KroneckerMatrix , alg:: KroneckerAlgorithm
649659 )
650660 return initialize_output ($ f, a. a, alg. a) .⊗ initialize_output ($ f, a. b, alg. b)
651661 end
652- function MatrixAlgebraKit. $f (a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm ; kwargs... )
653- $ f (a. a, Base. Fix2 (getfield, :a ).(F), alg. a; kwargs... )
654- $ f (a. b, Base. Fix2 (getfield, :b ).(F), alg. b; kwargs... )
662+ function MatrixAlgebraKit. $f (
663+ a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm ; kwargs1= (;), kwargs2= (;), kwargs...
664+ )
665+ $ f (a. a, Base. Fix2 (getfield, :a ).(F), alg. a; kwargs... , kwargs1... )
666+ $ f (a. b, Base. Fix2 (getfield, :b ).(F), alg. b; kwargs... , kwargs2... )
655667 return F
656668 end
657669 end
658670end
659671
660- for f in ( :eig_vals! , :eigh_vals! , :svd_vals! )
672+ for f in [ :eig_vals! , :eigh_vals! , :svd_vals! ]
661673 @eval begin
662674 function MatrixAlgebraKit. initialize_output (
663675 :: typeof ($ f), a:: KroneckerMatrix , alg:: KroneckerAlgorithm
@@ -672,7 +684,7 @@ for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
672684 end
673685end
674686
675- for f in ( :eig_trunc! , :eigh_trunc! , :svd_trunc! )
687+ for f in [ :eig_trunc! , :eigh_trunc! , :svd_trunc! ]
676688 @eval begin
677689 function MatrixAlgebraKit. truncate! (
678690 :: typeof ($ f),
@@ -684,25 +696,146 @@ for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
684696 end
685697end
686698
687- for f in ( :left_orth! , :right_orth! )
699+ for f in [ :left_orth! , :right_orth! ]
688700 @eval begin
689701 function MatrixAlgebraKit. initialize_output (:: typeof ($ f), a:: KroneckerMatrix )
690702 return initialize_output ($ f, a. a) .⊗ initialize_output ($ f, a. b)
691703 end
692704 end
693705end
694706
695- for f in ( :left_null! , :right_null! )
707+ for f in [ :left_null! , :right_null! ]
696708 @eval begin
697709 function MatrixAlgebraKit. initialize_output (:: typeof ($ f), a:: KroneckerMatrix )
698710 return initialize_output ($ f, a. a) ⊗ initialize_output ($ f, a. b)
699711 end
700- function MatrixAlgebraKit. $f (a:: KroneckerMatrix , F; kwargs... )
701- $ f (a. a, F. a; kwargs... )
702- $ f (a. b, F. b; kwargs... )
712+ function MatrixAlgebraKit. $f (a:: KroneckerMatrix , F; kwargs1 = (;), kwargs2 = (;), kwargs... )
713+ $ f (a. a, F. a; kwargs... , kwargs1 ... )
714+ $ f (a. b, F. b; kwargs... , kwargs2 ... )
703715 return F
704716 end
705717 end
706718end
707719
720+ # ###################################################################################
721+ # Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and
722+ # `A ⊗ Eye(n)` where `A`.
723+ # TODO : Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34
724+ # is merged.
725+
726+ using FillArrays: SquareEye
727+ const SquareEyeKronecker{T,A<: SquareEye{T} ,B<: AbstractMatrix{T} } = KroneckerMatrix{T,A,B}
728+ const KroneckerSquareEye{T,A<: AbstractMatrix{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
729+ const SquareEyeSquareEye{T,A<: SquareEye{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
730+
731+ struct SquareEyeAlgorithm <: AbstractAlgorithm end
732+
733+ # Defined to avoid type piracy.
734+ _copy_input_squareeye (f:: F , a) where {F} = copy_input (f, a)
735+ _copy_input_squareeye (f:: F , a:: SquareEye ) where {F} = a
736+
737+ for f in [
738+ :eig_full ,
739+ :eigh_full ,
740+ :qr_compact ,
741+ :qr_full ,
742+ :left_polar ,
743+ :lq_compact ,
744+ :lq_full ,
745+ :right_polar ,
746+ :svd_compact ,
747+ :svd_full ,
748+ ]
749+ for T in [:SquareEyeKronecker , :KroneckerSquareEye , :SquareEyeSquareEye ]
750+ @eval begin
751+ function MatrixAlgebraKit. copy_input (:: typeof ($ f), a:: $T )
752+ return _copy_input_squareeye ($ f, a. a) ⊗ _copy_input_squareeye ($ f, a. b)
753+ end
754+ end
755+ end
756+ end
757+
758+ for f in [
759+ :default_eig_algorithm ,
760+ :default_eigh_algorithm ,
761+ :default_lq_algorithm ,
762+ :default_qr_algorithm ,
763+ :default_polar_algorithm ,
764+ :default_svd_algorithm ,
765+ ]
766+ f′ = Symbol (" _" , f, " _squareeye" )
767+ @eval begin
768+ $ f′ (a) = $ f (a)
769+ $ f′ (a:: Type{<:SquareEye} ) = SquareEyeAlgorithm ()
770+ end
771+ for T in [:SquareEyeKronecker , :KroneckerSquareEye , :SquareEyeSquareEye ]
772+ @eval begin
773+ function MatrixAlgebraKit. $f (A:: Type{<:$T} ; kwargs1= (;), kwargs2= (;), kwargs... )
774+ A1, A2 = argument_types (A)
775+ return KroneckerAlgorithm (
776+ $ f′ (A1; kwargs... , kwargs1... ), $ f′ (A2; kwargs... , kwargs2... )
777+ )
778+ end
779+ end
780+ end
781+ end
782+
783+ # Defined to avoid type piracy.
784+ _initialize_output_squareeye (f:: F , a, alg) where {F} = initialize_output (f, a, alg)
785+ for f in [
786+ :eig_full! ,
787+ :eigh_full! ,
788+ :qr_compact! ,
789+ :qr_full! ,
790+ :left_polar! ,
791+ :lq_compact! ,
792+ :lq_full! ,
793+ :right_polar! ,
794+ ]
795+ @eval begin
796+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, a)
797+ end
798+ end
799+ for f in [:svd_compact! , :svd_full! ]
800+ @eval begin
801+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, a, a)
802+ end
803+ end
804+
805+ for f in [
806+ :eig_full! ,
807+ :eigh_full! ,
808+ :qr_compact! ,
809+ :qr_full! ,
810+ :left_polar! ,
811+ :lq_compact! ,
812+ :lq_full! ,
813+ :right_polar! ,
814+ :svd_compact! ,
815+ :svd_full! ,
816+ ]
817+ f′ = Symbol (" _" , f, " _squareeye" )
818+ @eval begin
819+ $ f′ (a, F, alg; kwargs... ) = $ f (a, F, alg; kwargs... )
820+ $ f′ (a, F, alg:: SquareEyeAlgorithm ) = F
821+ end
822+ for T in [:SquareEyeKronecker , :KroneckerSquareEye , :SquareEyeSquareEye ]
823+ @eval begin
824+ function MatrixAlgebraKit. initialize_output (
825+ :: typeof ($ f), a:: $T , alg:: KroneckerAlgorithm
826+ )
827+ return _initialize_output_squareeye ($ f, a. a, alg. a) .⊗
828+ _initialize_output_squareeye ($ f, a. b, alg. b)
829+ end
830+ function MatrixAlgebraKit. $f (
831+ a:: $T , F, alg:: KroneckerAlgorithm ; kwargs1= (;), kwargs2= (;), kwargs...
832+ )
833+ $ f′ (a. a, Base. Fix2 (getfield, :a ).(F), alg. a; kwargs... , kwargs1... )
834+ $ f′ (a. b, Base. Fix2 (getfield, :b ).(F), alg. b; kwargs... , kwargs2... )
835+ return F
836+ end
837+ end
838+ end
839+ end
840+
708841end
0 commit comments