Skip to content

Commit 2ba5f40

Browse files
committed
eig and svd vals
1 parent 3a93e83 commit 2ba5f40

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

src/KroneckerArrays.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,9 @@ end
579579
using MatrixAlgebraKit:
580580
copy_input,
581581
eig_full,
582+
eig_vals,
582583
eigh_full,
584+
eigh_vals,
583585
qr_compact,
584586
qr_full,
585587
left_orth,
@@ -741,7 +743,9 @@ _copy_input_squareeye(f::F, a::SquareEye) where {F} = a
741743

742744
for f in [
743745
:eig_full,
746+
:eig_vals,
744747
:eigh_full,
748+
:eigh_vals,
745749
:qr_compact,
746750
:qr_full,
747751
:left_orth,
@@ -857,4 +861,36 @@ for f in [
857861
end
858862
end
859863

864+
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
865+
@eval begin
866+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = parent(a)
867+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = parent(a)
868+
end
869+
end
870+
871+
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
872+
f′ = Symbol("_", f, "_squareeye")
873+
@eval begin
874+
$f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...)
875+
$f′(a, F, alg::SquareEyeAlgorithm) = F
876+
end
877+
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
878+
@eval begin
879+
function MatrixAlgebraKit.initialize_output(
880+
::typeof($f), a::$T, alg::KroneckerAlgorithm
881+
)
882+
return _initialize_output_squareeye($f, a.a, alg.a)
883+
_initialize_output_squareeye($f, a.b, alg.b)
884+
end
885+
function MatrixAlgebraKit.$f(
886+
a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
887+
)
888+
$f′(a.a, F.a, alg.a; kwargs..., kwargs1...)
889+
$f′(a.b, F.b, alg.b; kwargs..., kwargs2...)
890+
return F
891+
end
892+
end
893+
end
894+
end
895+
860896
end

test/test_matrixalgebrakit.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using FillArrays: Eye
1+
using FillArrays: Eye, Ones
22
using KroneckerArrays: , arguments
33
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
44
using MatrixAlgebraKit:
@@ -154,6 +154,28 @@ end
154154
@test arguments(v, 2) isa Eye
155155
end
156156

157+
for f in (eig_vals, eigh_vals)
158+
a = Eye(3) parent(hermitianpart(randn(3, 3)))
159+
d = @constinferred f(a)
160+
d′ = f(Matrix(a))
161+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
162+
@test arguments(d, 1) isa Ones
163+
@test arguments(d, 2) f(arguments(a, 2))
164+
165+
a = parent(hermitianpart(randn(3, 3))) Eye(3)
166+
d = @constinferred f(a)
167+
d′ = f(Matrix(a))
168+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
169+
@test arguments(d, 2) isa Ones
170+
@test arguments(d, 1) f(arguments(a, 1))
171+
172+
a = Eye(3) Eye(3)
173+
d = @constinferred f(a)
174+
@test d == Ones(3) Ones(3)
175+
@test arguments(d, 1) isa Ones
176+
@test arguments(d, 2) isa Ones
177+
end
178+
157179
for f in (
158180
left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar
159181
)
@@ -203,4 +225,24 @@ end
203225
@test arguments(s, 2) isa Eye
204226
@test arguments(v, 2) isa Eye
205227
end
228+
229+
a = Eye(3) randn(3, 3)
230+
d = @constinferred svd_vals(a)
231+
d′ = svd_vals(Matrix(a))
232+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
233+
@test arguments(d, 1) isa Ones
234+
@test arguments(d, 2) svd_vals(arguments(a, 2))
235+
236+
a = randn(3, 3) Eye(3)
237+
d = @constinferred svd_vals(a)
238+
d′ = svd_vals(Matrix(a))
239+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
240+
@test arguments(d, 2) isa Ones
241+
@test arguments(d, 1) svd_vals(arguments(a, 1))
242+
243+
a = Eye(3) Eye(3)
244+
d = @constinferred svd_vals(a)
245+
@test d == Ones(3) Ones(3)
246+
@test arguments(d, 1) isa Ones
247+
@test arguments(d, 2) isa Ones
206248
end

0 commit comments

Comments
 (0)