Skip to content

Commit 6af509f

Browse files
authored
eig[h]_vals, svd_vals, left_null, right_null (#8)
1 parent 3a93e83 commit 6af509f

File tree

3 files changed

+161
-12
lines changed

3 files changed

+161
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

src/KroneckerArrays.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,13 +579,17 @@ end
579579
using MatrixAlgebraKit:
580580
copy_input,
581581
eig_full,
582+
eig_vals,
582583
eigh_full,
584+
eigh_vals,
583585
qr_compact,
584586
qr_full,
587+
left_null,
585588
left_orth,
586589
left_polar,
587590
lq_compact,
588591
lq_full,
592+
right_null,
589593
right_orth,
590594
right_polar,
591595
svd_compact,
@@ -741,13 +745,17 @@ _copy_input_squareeye(f::F, a::SquareEye) where {F} = a
741745

742746
for f in [
743747
:eig_full,
748+
:eig_vals,
744749
:eigh_full,
750+
:eigh_vals,
745751
:qr_compact,
746752
:qr_full,
753+
:left_null,
747754
:left_orth,
748755
:left_polar,
749756
:lq_compact,
750757
:lq_full,
758+
:right_null,
751759
:right_orth,
752760
:right_polar,
753761
:svd_compact,
@@ -791,6 +799,12 @@ end
791799
_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a)
792800
_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg)
793801

802+
for f in [:left_null!, :right_null!]
803+
@eval begin
804+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = a
805+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = a
806+
end
807+
end
794808
for f in [
795809
:eig_full!,
796810
:eigh_full!,
@@ -857,4 +871,74 @@ for f in [
857871
end
858872
end
859873

874+
for f in [:left_null!, :right_null!]
875+
f′ = Symbol("_", f, "_squareeye")
876+
@eval begin
877+
$f′(a, F; kwargs...) = $f(a, F; kwargs...)
878+
$f′(a::SquareEye, F) = F
879+
end
880+
for T in [:SquareEyeKronecker, :KroneckerSquareEye]
881+
@eval begin
882+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T)
883+
return _initialize_output_squareeye($f, a.a) _initialize_output_squareeye($f, a.b)
884+
end
885+
function MatrixAlgebraKit.$f(a::$T, F; kwargs1=(;), kwargs2=(;), kwargs...)
886+
$f′(a.a, F.a; kwargs..., kwargs1...)
887+
$f′(a.b, F.b; kwargs..., kwargs2...)
888+
return F
889+
end
890+
end
891+
end
892+
end
893+
894+
function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::SquareEyeSquareEye)
895+
return _initialize_output_squareeye(f, a.a) _initialize_output_squareeye(f, a.b)
896+
end
897+
function MatrixAlgebraKit.left_null!(
898+
a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs...
899+
)
900+
return throw(MethodError(left_null!, (a, F)))
901+
end
902+
903+
function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::SquareEyeSquareEye)
904+
return _initialize_output_squareeye(f, a.a) _initialize_output_squareeye(f, a.b)
905+
end
906+
function MatrixAlgebraKit.right_null!(
907+
a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs...
908+
)
909+
return throw(MethodError(right_null!, (a, F)))
910+
end
911+
912+
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
913+
@eval begin
914+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = parent(a)
915+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = parent(a)
916+
end
917+
end
918+
919+
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
920+
f′ = Symbol("_", f, "_squareeye")
921+
@eval begin
922+
$f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...)
923+
$f′(a, F, alg::SquareEyeAlgorithm) = F
924+
end
925+
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
926+
@eval begin
927+
function MatrixAlgebraKit.initialize_output(
928+
::typeof($f), a::$T, alg::KroneckerAlgorithm
929+
)
930+
return _initialize_output_squareeye($f, a.a, alg.a)
931+
_initialize_output_squareeye($f, a.b, alg.b)
932+
end
933+
function MatrixAlgebraKit.$f(
934+
a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
935+
)
936+
$f′(a.a, F.a, alg.a; kwargs..., kwargs1...)
937+
$f′(a.b, F.b, alg.b; kwargs..., kwargs2...)
938+
return F
939+
end
940+
end
941+
end
942+
end
943+
860944
end

test/test_matrixalgebrakit.jl

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using FillArrays: Eye
1+
using FillArrays: Eye, Ones
22
using KroneckerArrays: , arguments
33
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
44
using MatrixAlgebraKit:
@@ -124,13 +124,8 @@ end
124124

125125
# TODO:
126126
# eig_trunc
127-
# eig_vals
128127
# eigh_trunc
129-
# eigh_vals
130-
# left_null
131-
# right_null
132128
# svd_trunc
133-
# svd_vals
134129

135130
for f in (eig_full, eigh_full)
136131
a = Eye(3) parent(hermitianpart(randn(3, 3)))
@@ -154,17 +149,39 @@ end
154149
@test arguments(v, 2) isa Eye
155150
end
156151

152+
for f in (eig_vals, eigh_vals)
153+
a = Eye(3) parent(hermitianpart(randn(3, 3)))
154+
d = @constinferred f(a)
155+
d′ = f(Matrix(a))
156+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
157+
@test arguments(d, 1) isa Ones
158+
@test arguments(d, 2) f(arguments(a, 2))
159+
160+
a = parent(hermitianpart(randn(3, 3))) Eye(3)
161+
d = @constinferred f(a)
162+
d′ = f(Matrix(a))
163+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
164+
@test arguments(d, 2) isa Ones
165+
@test arguments(d, 1) f(arguments(a, 1))
166+
167+
a = Eye(3) Eye(3)
168+
d = @constinferred f(a)
169+
@test d == Ones(3) Ones(3)
170+
@test arguments(d, 1) isa Ones
171+
@test arguments(d, 2) isa Ones
172+
end
173+
157174
for f in (
158175
left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar
159176
)
160177
a = Eye(3) randn(3, 3)
161-
x, y = f(a)
178+
x, y = @constinferred f(a)
162179
@test x * y a
163180
@test arguments(x, 1) isa Eye
164181
@test arguments(y, 1) isa Eye
165182

166183
a = randn(3, 3) Eye(3)
167-
x, y = f(a)
184+
x, y = @constinferred f(a)
168185
@test x * y a
169186
@test arguments(x, 2) isa Eye
170187
@test arguments(y, 2) isa Eye
@@ -180,21 +197,21 @@ end
180197

181198
for f in (svd_compact, svd_full)
182199
a = Eye(3) randn(3, 3)
183-
u, s, v = f(a)
200+
u, s, v = @constinferred f(a)
184201
@test u * s * v a
185202
@test arguments(u, 1) isa Eye
186203
@test arguments(s, 1) isa Eye
187204
@test arguments(v, 1) isa Eye
188205

189206
a = randn(3, 3) Eye(3)
190-
u, s, v = f(a)
207+
u, s, v = @constinferred f(a)
191208
@test u * s * v a
192209
@test arguments(u, 2) isa Eye
193210
@test arguments(s, 2) isa Eye
194211
@test arguments(v, 2) isa Eye
195212

196213
a = Eye(3) Eye(3)
197-
u, s, v = f(a)
214+
u, s, v = @constinferred f(a)
198215
@test u * s * v a
199216
@test arguments(u, 1) isa Eye
200217
@test arguments(s, 1) isa Eye
@@ -203,4 +220,52 @@ end
203220
@test arguments(s, 2) isa Eye
204221
@test arguments(v, 2) isa Eye
205222
end
223+
224+
a = Eye(3) randn(3, 3)
225+
d = @constinferred svd_vals(a)
226+
d′ = svd_vals(Matrix(a))
227+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
228+
@test arguments(d, 1) isa Ones
229+
@test arguments(d, 2) svd_vals(arguments(a, 2))
230+
231+
a = randn(3, 3) Eye(3)
232+
d = @constinferred svd_vals(a)
233+
d′ = svd_vals(Matrix(a))
234+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
235+
@test arguments(d, 2) isa Ones
236+
@test arguments(d, 1) svd_vals(arguments(a, 1))
237+
238+
a = Eye(3) Eye(3)
239+
d = @constinferred svd_vals(a)
240+
@test d == Ones(3) Ones(3)
241+
@test arguments(d, 1) isa Ones
242+
@test arguments(d, 2) isa Ones
243+
244+
# left_null
245+
a = Eye(3) randn(3, 3)
246+
n = @constinferred left_null(a)
247+
@test norm(n' * a) 0
248+
@test arguments(n, 1) isa Eye
249+
250+
a = randn(3, 3) Eye(3)
251+
n = @constinferred left_null(a)
252+
@test norm(n' * a) 0
253+
@test arguments(n, 2) isa Eye
254+
255+
a = Eye(3) Eye(3)
256+
@test_throws MethodError left_null(a)
257+
258+
# right_null
259+
a = Eye(3) randn(3, 3)
260+
n = @constinferred right_null(a)
261+
@test norm(a * n') 0
262+
@test arguments(n, 1) isa Eye
263+
264+
a = randn(3, 3) Eye(3)
265+
n = @constinferred right_null(a)
266+
@test norm(a * n') 0
267+
@test arguments(n, 2) isa Eye
268+
269+
a = Eye(3) Eye(3)
270+
@test_throws MethodError right_null(a)
206271
end

0 commit comments

Comments
 (0)