Skip to content

Commit d9c31b2

Browse files
committed
Fix left_orth, right_orth
1 parent 2e35549 commit d9c31b2

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

src/KroneckerArrays.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,9 +582,11 @@ using MatrixAlgebraKit:
582582
eigh_full,
583583
qr_compact,
584584
qr_full,
585+
left_orth,
585586
left_polar,
586587
lq_compact,
587588
lq_full,
589+
right_orth,
588590
right_polar,
589591
svd_compact,
590592
svd_full
@@ -728,7 +730,10 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
728730
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
729731
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
730732

731-
struct SquareEyeAlgorithm <: AbstractAlgorithm end
733+
struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm
734+
kwargs::KWargs
735+
end
736+
SquareEyeAlgorithm(; kwargs...) = SquareEyeAlgorithm((; kwargs...))
732737

733738
# Defined to avoid type piracy.
734739
_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a)
@@ -739,9 +744,11 @@ for f in [
739744
:eigh_full,
740745
:qr_compact,
741746
:qr_full,
747+
:left_orth,
742748
:left_polar,
743749
:lq_compact,
744750
:lq_full,
751+
:right_orth,
745752
:right_polar,
746753
:svd_compact,
747754
:svd_full,
@@ -765,8 +772,8 @@ for f in [
765772
]
766773
f′ = Symbol("_", f, "_squareeye")
767774
@eval begin
768-
$f′(a) = $f(a)
769-
$f′(a::Type{<:SquareEye}) = SquareEyeAlgorithm()
775+
$f′(a; kwargs...) = $f(a; kwargs...)
776+
$f′(a::Type{<:SquareEye}; kwargs...) = SquareEyeAlgorithm(; kwargs...)
770777
end
771778
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
772779
@eval begin
@@ -781,23 +788,29 @@ for f in [
781788
end
782789

783790
# Defined to avoid type piracy.
791+
_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a)
784792
_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg)
793+
785794
for f in [
786795
:eig_full!,
787796
:eigh_full!,
788797
:qr_compact!,
789798
:qr_full!,
799+
:left_orth!,
790800
:left_polar!,
791801
:lq_compact!,
792802
:lq_full!,
803+
:right_orth!,
793804
:right_polar!,
794805
]
795806
@eval begin
807+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a)
796808
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a)
797809
end
798810
end
799811
for f in [:svd_compact!, :svd_full!]
800812
@eval begin
813+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a, a)
801814
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a)
802815
end
803816
end
@@ -807,9 +820,11 @@ for f in [
807820
:eigh_full!,
808821
:qr_compact!,
809822
:qr_full!,
823+
:left_orth!,
810824
:left_polar!,
811825
:lq_compact!,
812826
:lq_full!,
827+
:right_orth!,
813828
:right_polar!,
814829
:svd_compact!,
815830
:svd_full!,
@@ -821,6 +836,10 @@ for f in [
821836
end
822837
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
823838
@eval begin
839+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T)
840+
return _initialize_output_squareeye($f, a.a) .⊗
841+
_initialize_output_squareeye($f, a.b)
842+
end
824843
function MatrixAlgebraKit.initialize_output(
825844
::typeof($f), a::$T, alg::KroneckerAlgorithm
826845
)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
77
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
88
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
99
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10+
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1011

1112
[compat]
1213
Aqua = "0.8"
@@ -17,3 +18,4 @@ MatrixAlgebraKit = "0.2"
1718
SafeTestsets = "0.1"
1819
Suppressor = "0.2"
1920
Test = "1.10"
21+
TestExtras = "0.3"

test/test_matrixalgebrakit.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,7 @@ end
156156
end
157157

158158
for f in (
159-
#=left_orth,=#left_polar,
160-
lq_compact,
161-
lq_full,
162-
qr_compact,
163-
qr_full,
164-
#=right_orth,=#right_polar,
159+
left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar
165160
)
166161
a = Eye(3) randn(3, 3)
167162
x, y = f(a)

0 commit comments

Comments
 (0)