Skip to content

Commit b3f197a

Browse files
committed
Improve element types in SVD and eig, better tests
1 parent 766b435 commit b3f197a

File tree

2 files changed

+150
-82
lines changed

2 files changed

+150
-82
lines changed

src/KroneckerArrays.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,8 +1070,6 @@ for f in [:left_null!, :right_null!]
10701070
end
10711071
end
10721072
for f in [
1073-
:eig_full!,
1074-
:eigh_full!,
10751073
:qr_compact!,
10761074
:qr_full!,
10771075
:left_orth!,
@@ -1086,10 +1084,14 @@ for f in [
10861084
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a)
10871085
end
10881086
end
1087+
_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye) = complex.((a, a))
1088+
_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye, alg) = complex.((a, a))
1089+
_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye) = (real(a), a)
1090+
_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye, alg) = (real(a), a)
10891091
for f in [:svd_compact!, :svd_full!]
10901092
@eval begin
1091-
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a, a)
1092-
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a)
1093+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, real(a), a)
1094+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, real(a), a)
10931095
end
10941096
end
10951097

@@ -1173,10 +1175,12 @@ function MatrixAlgebraKit.right_null!(
11731175
return throw(MethodError(right_null!, (a, F)))
11741176
end
11751177

1176-
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
1178+
_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye) = parent(a)
1179+
_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye, alg) = parent(a)
1180+
for f in [:eigh_vals!, svd_vals!]
11771181
@eval begin
1178-
_initialize_output_squareeye(::typeof($f), a::SquareEye) = parent(a)
1179-
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = parent(a)
1182+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = real(parent(a))
1183+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = real(parent(a))
11801184
end
11811185
end
11821186

test/test_matrixalgebrakit.jl

Lines changed: 139 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,16 @@ herm(a) = parent(hermitianpart(a))
4444
a = herm(randn(elt, 2, 2)) herm(randn(elt, 3, 3))
4545
d, v = eigh_full(a)
4646
@test a * v v * d
47+
@test eltype(d) === real(elt)
48+
@test eltype(v) === elt
4749

4850
a = herm(randn(elt, 2, 2)) herm(randn(elt, 3, 3))
4951
@test_throws MethodError eigh_trunc(a)
5052

5153
a = herm(randn(elt, 2, 2)) herm(randn(elt, 3, 3))
5254
d = eigh_vals(a)
5355
@test d diag(eigh_full(a)[1])
56+
@test eltype(d) === real(elt)
5457

5558
a = randn(elt, 2, 2) randn(elt, 3, 3)
5659
u, c = qr_compact(a)
@@ -103,12 +106,18 @@ herm(a) = parent(hermitianpart(a))
103106
a = randn(elt, 2, 2) randn(elt, 3, 3)
104107
u, s, v = svd_compact(a)
105108
@test u * s * v a
109+
@test eltype(u) === elt
110+
@test eltype(s) === real(elt)
111+
@test eltype(v) === elt
106112
@test collect(u'u) I
107113
@test collect(v * v') I
108114

109115
a = randn(elt, 2, 2) randn(elt, 3, 3)
110116
u, s, v = svd_full(a)
111117
@test u * s * v a
118+
@test eltype(u) === elt
119+
@test eltype(s) === real(elt)
120+
@test eltype(v) === elt
112121
@test collect(u'u) I
113122
@test collect(v * v') I
114123

@@ -121,26 +130,48 @@ herm(a) = parent(hermitianpart(a))
121130
end
122131

123132
@testset "MatrixAlgebraKit + Eye" begin
124-
for f in (eig_full, eigh_full)
125-
a = Eye(3) parent(hermitianpart(randn(3, 3)))
126-
d, v = @constinferred f(a)
133+
for elt in (Float32, ComplexF32)
134+
a = Eye{elt}(3) randn(elt, 3, 3)
135+
d, v = @constinferred eig_full(a)
127136
@test a * v v * d
128-
@test arguments(d, 1) isa Eye
129-
@test arguments(v, 1) isa Eye
137+
@test arguments(d, 1) isa Eye{complex(elt)}
138+
@test arguments(v, 1) isa Eye{complex(elt)}
130139

131-
a = parent(hermitianpart(randn(3, 3))) Eye(3)
132-
d, v = @constinferred f(a)
140+
a = parent(hermitianpart(randn(elt, 3, 3))) Eye{elt}(3)
141+
d, v = @constinferred eig_full(a)
133142
@test a * v v * d
134-
@test arguments(d, 2) isa Eye
135-
@test arguments(v, 2) isa Eye
143+
@test arguments(d, 2) isa Eye{complex(elt)}
144+
@test arguments(v, 2) isa Eye{complex(elt)}
136145

137-
a = Eye(3) Eye(3)
138-
d, v = @constinferred f(a)
146+
a = Eye{elt}(3) Eye{elt}(3)
147+
d, v = @constinferred eig_full(a)
139148
@test a * v v * d
140-
@test arguments(d, 1) isa Eye
141-
@test arguments(d, 2) isa Eye
142-
@test arguments(v, 1) isa Eye
143-
@test arguments(v, 2) isa Eye
149+
@test arguments(d, 1) isa Eye{complex(elt)}
150+
@test arguments(d, 2) isa Eye{complex(elt)}
151+
@test arguments(v, 1) isa Eye{complex(elt)}
152+
@test arguments(v, 2) isa Eye{complex(elt)}
153+
end
154+
155+
for elt in (Float32, ComplexF32)
156+
a = Eye{elt}(3) parent(hermitianpart(randn(elt, 3, 3)))
157+
d, v = @constinferred eigh_full(a)
158+
@test a * v v * d
159+
@test arguments(d, 1) isa Eye{real(elt)}
160+
@test arguments(v, 1) isa Eye{elt}
161+
162+
a = parent(hermitianpart(randn(elt, 3, 3))) Eye{elt}(3)
163+
d, v = @constinferred eigh_full(a)
164+
@test a * v v * d
165+
@test arguments(d, 2) isa Eye{real(elt)}
166+
@test arguments(v, 2) isa Eye{elt}
167+
168+
a = Eye{elt}(3) Eye{elt}(3)
169+
d, v = @constinferred eigh_full(a)
170+
@test a * v v * d
171+
@test arguments(d, 1) isa Eye{real(elt)}
172+
@test arguments(d, 2) isa Eye{real(elt)}
173+
@test arguments(v, 1) isa Eye{elt}
174+
@test arguments(v, 2) isa Eye{elt}
144175
end
145176

146177
for f in (eig_trunc, eigh_trunc)
@@ -211,77 +242,110 @@ end
211242
end
212243

213244
for f in (svd_compact, svd_full)
214-
a = Eye(3) randn(3, 3)
215-
u, s, v = @constinferred f(a)
216-
@test u * s * v a
217-
@test arguments(u, 1) isa Eye
218-
@test arguments(s, 1) isa Eye
219-
@test arguments(v, 1) isa Eye
220-
221-
a = randn(3, 3) Eye(3)
222-
u, s, v = @constinferred f(a)
223-
@test u * s * v a
224-
@test arguments(u, 2) isa Eye
225-
@test arguments(s, 2) isa Eye
226-
@test arguments(v, 2) isa Eye
227-
228-
a = Eye(3) Eye(3)
229-
u, s, v = @constinferred f(a)
230-
@test u * s * v a
231-
@test arguments(u, 1) isa Eye
232-
@test arguments(s, 1) isa Eye
233-
@test arguments(v, 1) isa Eye
234-
@test arguments(u, 2) isa Eye
235-
@test arguments(s, 2) isa Eye
236-
@test arguments(v, 2) isa Eye
245+
for elt in (Float32, ComplexF32)
246+
a = Eye{elt}(3) randn(elt, 3, 3)
247+
u, s, v = @constinferred f(a)
248+
@test u * s * v a
249+
@test eltype(u) === elt
250+
@test eltype(s) === real(elt)
251+
@test eltype(v) === elt
252+
@test arguments(u, 1) isa Eye{elt}
253+
@test arguments(s, 1) isa Eye{real(elt)}
254+
@test arguments(v, 1) isa Eye{elt}
255+
256+
a = randn(elt, 3, 3) Eye{elt}(3)
257+
u, s, v = @constinferred f(a)
258+
@test u * s * v a
259+
@test eltype(u) === elt
260+
@test eltype(s) === real(elt)
261+
@test eltype(v) === elt
262+
@test arguments(u, 2) isa Eye{elt}
263+
@test arguments(s, 2) isa Eye{real(elt)}
264+
@test arguments(v, 2) isa Eye{elt}
265+
266+
a = Eye{elt}(3) Eye{elt}(3)
267+
u, s, v = @constinferred f(a)
268+
@test u * s * v a
269+
@test eltype(u) === elt
270+
@test eltype(s) === real(elt)
271+
@test eltype(v) === elt
272+
@test arguments(u, 1) isa Eye{elt}
273+
@test arguments(s, 1) isa Eye{real(elt)}
274+
@test arguments(v, 1) isa Eye{elt}
275+
@test arguments(u, 2) isa Eye{elt}
276+
@test arguments(s, 2) isa Eye{real(elt)}
277+
@test arguments(v, 2) isa Eye{elt}
278+
end
237279
end
238280

239281
# svd_trunc
240-
a = Eye(3) randn(3, 3)
241-
u, s, v = svd_trunc(a; trunc=(; maxrank=7))
242-
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
243-
@test Matrix(u * s * v) u′ * s′ * v′
244-
@test arguments(u, 1) isa Eye
245-
@test arguments(s, 1) isa Eye
246-
@test arguments(v, 1) isa Eye
247-
@test size(u) == (9, 6)
248-
@test size(s) == (6, 6)
249-
@test size(v) == (6, 9)
282+
for elt in (Float32, ComplexF32)
283+
a = Eye{elt}(3) randn(elt, 3, 3)
284+
# TODO: Type inference is broken for `svd_trunc`,
285+
# look into fixing it.
286+
# u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
287+
u, s, v = svd_trunc(a; trunc=(; maxrank=7))
288+
@test eltype(u) === elt
289+
@test eltype(s) === real(elt)
290+
@test eltype(v) === elt
291+
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
292+
@test Matrix(u * s * v) u′ * s′ * v′
293+
@test arguments(u, 1) isa Eye{elt}
294+
@test arguments(s, 1) isa Eye{real(elt)}
295+
@test arguments(v, 1) isa Eye{elt}
296+
@test size(u) == (9, 6)
297+
@test size(s) == (6, 6)
298+
@test size(v) == (6, 9)
299+
end
250300

251-
a = randn(3, 3) Eye(3)
252-
u, s, v = svd_trunc(a; trunc=(; maxrank=7))
253-
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
254-
@test Matrix(u * s * v) u′ * s′ * v′
255-
@test arguments(u, 2) isa Eye
256-
@test arguments(s, 2) isa Eye
257-
@test arguments(v, 2) isa Eye
258-
@test size(u) == (9, 6)
259-
@test size(s) == (6, 6)
260-
@test size(v) == (6, 9)
301+
for elt in (Float32, ComplexF32)
302+
a = randn(elt, 3, 3) Eye{elt}(3)
303+
# TODO: Type inference is broken for `svd_trunc`,
304+
# look into fixing it.
305+
# u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
306+
u, s, v = svd_trunc(a; trunc=(; maxrank=7))
307+
@test eltype(u) === elt
308+
@test eltype(s) === real(elt)
309+
@test eltype(v) === elt
310+
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
311+
@test Matrix(u * s * v) u′ * s′ * v′
312+
@test arguments(u, 2) isa Eye{elt}
313+
@test arguments(s, 2) isa Eye{real(elt)}
314+
@test arguments(v, 2) isa Eye{elt}
315+
@test size(u) == (9, 6)
316+
@test size(s) == (6, 6)
317+
@test size(v) == (6, 9)
318+
end
261319

262320
a = Eye(3) Eye(3)
263321
@test_throws ArgumentError svd_trunc(a)
264322

265323
# svd_vals
266-
a = Eye(3) randn(3, 3)
267-
d = @constinferred svd_vals(a)
268-
d′ = svd_vals(Matrix(a))
269-
@test sort(Vector(d); by=abs) sort(d′; by=abs)
270-
@test arguments(d, 1) isa Ones
271-
@test arguments(d, 2) svd_vals(arguments(a, 2))
324+
for elt in (Float32, ComplexF32)
325+
a = Eye{elt}(3) randn(elt, 3, 3)
326+
d = @constinferred svd_vals(a)
327+
d′ = svd_vals(Matrix(a))
328+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
329+
@test arguments(d, 1) isa Ones{real(elt)}
330+
@test arguments(d, 2) svd_vals(arguments(a, 2))
331+
end
272332

273-
a = randn(3, 3) Eye(3)
274-
d = @constinferred svd_vals(a)
275-
d′ = svd_vals(Matrix(a))
276-
@test sort(Vector(d); by=abs) sort(d′; by=abs)
277-
@test arguments(d, 2) isa Ones
278-
@test arguments(d, 1) svd_vals(arguments(a, 1))
333+
for elt in (Float32, ComplexF32)
334+
a = randn(elt, 3, 3) Eye{elt}(3)
335+
d = @constinferred svd_vals(a)
336+
d′ = svd_vals(Matrix(a))
337+
@test sort(Vector(d); by=abs) sort(d′; by=abs)
338+
@test arguments(d, 2) isa Ones{real(elt)}
339+
@test arguments(d, 1) svd_vals(arguments(a, 1))
340+
end
279341

280-
a = Eye(3) Eye(3)
281-
d = @constinferred svd_vals(a)
282-
@test d == Ones(3) Ones(3)
283-
@test arguments(d, 1) isa Ones
284-
@test arguments(d, 2) isa Ones
342+
for elt in (Float32, ComplexF32)
343+
a = Eye{elt}(3) Eye{elt}(3)
344+
d = @constinferred svd_vals(a)
345+
@test d == Ones(3) Ones(3)
346+
@test arguments(d, 1) isa Ones{real(elt)}
347+
@test arguments(d, 2) isa Ones{real(elt)}
348+
end
285349

286350
# left_null
287351
a = Eye(3) randn(3, 3)

0 commit comments

Comments
 (0)