|
158 | 158 |
|
159 | 159 | arguments(a::KroneckerArray) = (a.a, a.b) |
160 | 160 | arguments(a::KroneckerArray, n::Int) = arguments(a)[n] |
| 161 | +argument_types(a::KroneckerArray) = argument_types(typeof(a)) |
| 162 | +argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B) |
161 | 163 |
|
162 | 164 | function Base.print_array(io::IO, a::KroneckerArray) |
163 | 165 | Base.print_array(io, a.a) |
@@ -609,12 +611,26 @@ end |
609 | 611 | for f in (:eig, :eigh, :lq, :qr, :polar, :svd) |
610 | 612 | ff = Symbol("default_", f, "_algorithm") |
611 | 613 | @eval begin |
612 | | - function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...) |
613 | | - return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...)) |
| 614 | + function MatrixAlgebraKit.$ff(A::Type{<:KroneckerMatrix}; kwargs...) |
| 615 | + A1, A2 = argument_types(A) |
| 616 | + return KroneckerAlgorithm($ff(A1; kwargs...), $ff(A2; kwargs...)) |
614 | 617 | end |
615 | 618 | end |
616 | 619 | end |
617 | 620 |
|
| 621 | +# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. |
| 622 | +function MatrixAlgebraKit.default_algorithm( |
| 623 | + ::typeof(qr_compact!), A::Type{<:KroneckerMatrix}; kwargs... |
| 624 | +) |
| 625 | + return default_qr_algorithm(A; kwargs...) |
| 626 | +end |
| 627 | +# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. |
| 628 | +function MatrixAlgebraKit.default_algorithm( |
| 629 | + ::typeof(qr_full!), A::Type{<:KroneckerMatrix}; kwargs... |
| 630 | +) |
| 631 | + return default_qr_algorithm(A; kwargs...) |
| 632 | +end |
| 633 | + |
618 | 634 | for f in ( |
619 | 635 | :eig_full!, |
620 | 636 | :eigh_full!, |
@@ -689,75 +705,4 @@ for f in (:left_null!, :right_null!) |
689 | 705 | end |
690 | 706 | end |
691 | 707 |
|
692 | | -# Special case for `FillArrays.Eye` matrices. |
693 | | -struct EyeAlgorithm <: AbstractAlgorithm end |
694 | | - |
695 | | -for f in [ |
696 | | - :eig_full, |
697 | | - :eigh_full, |
698 | | - :qr_compact, |
699 | | - :qr_full, |
700 | | - :left_polar, |
701 | | - :lq_compact, |
702 | | - :lq_full, |
703 | | - :right_polar, |
704 | | - :svd_compact, |
705 | | - :svd_full, |
706 | | -] |
707 | | - @eval begin |
708 | | - MatrixAlgebraKit.copy_input(::typeof($f), a::Eye) = a |
709 | | - end |
710 | | -end |
711 | | - |
712 | | -for f in (:eig, :eigh, :lq, :qr, :polar, :svd) |
713 | | - ff = Symbol("default_", f, "_algorithm") |
714 | | - @eval begin |
715 | | - function MatrixAlgebraKit.$ff(a::Eye; kwargs...) |
716 | | - return EyeAlgorithm() |
717 | | - end |
718 | | - end |
719 | | -end |
720 | | - |
721 | | -for f in ( |
722 | | - :eig_full!, |
723 | | - :eigh_full!, |
724 | | - :qr_compact!, |
725 | | - :qr_full!, |
726 | | - :left_polar!, |
727 | | - :lq_compact!, |
728 | | - :lq_full!, |
729 | | - :right_polar!, |
730 | | -) |
731 | | - @eval begin |
732 | | - nfactors(::typeof($f)) = 2 |
733 | | - end |
734 | | -end |
735 | | -for f in (:svd_compact!, :svd_full!) |
736 | | - @eval begin |
737 | | - nfactors(::typeof($f)) = 3 |
738 | | - end |
739 | | -end |
740 | | - |
741 | | -for f in ( |
742 | | - :eig_full!, |
743 | | - :eigh_full!, |
744 | | - :qr_compact!, |
745 | | - :qr_full!, |
746 | | - :left_polar!, |
747 | | - :lq_compact!, |
748 | | - :lq_full!, |
749 | | - :right_polar!, |
750 | | - :svd_compact!, |
751 | | - :svd_full!, |
752 | | -) |
753 | | - @eval begin |
754 | | - function MatrixAlgebraKit.initialize_output(::typeof($f), a::Eye, alg::EyeAlgorithm) |
755 | | - return ntuple(_ -> a, nfactors($f)) |
756 | | - end |
757 | | - function MatrixAlgebraKit.$f(a::Eye, F, alg::EyeAlgorithm; kwargs...) |
758 | | - return ntuple(_ -> a, nfactors($f)) |
759 | | - end |
760 | | - end |
761 | | -end |
762 | | - |
763 | 708 | end |
0 commit comments