Skip to content

Commit 24286ec

Browse files
committed
More factorizations and tests
1 parent 2996541 commit 24286ec

File tree

2 files changed

+145
-83
lines changed

2 files changed

+145
-83
lines changed

src/KroneckerArrays.jl

Lines changed: 85 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,10 @@ using MatrixAlgebraKit:
381381
TruncationStrategy,
382382
default_eig_algorithm,
383383
default_eigh_algorithm,
384+
default_lq_algorithm,
385+
default_polar_algorithm,
384386
default_qr_algorithm,
387+
default_svd_algorithm,
385388
eig_full!,
386389
eig_trunc!,
387390
eig_vals!,
@@ -390,107 +393,107 @@ using MatrixAlgebraKit:
390393
eigh_vals!,
391394
initialize_output,
392395
left_null!,
396+
left_orth!,
397+
left_polar!,
398+
lq_compact!,
399+
lq_full!,
400+
qr_compact!,
401+
qr_full!,
393402
right_null!,
403+
right_orth!,
404+
right_polar!,
405+
svd_compact!,
406+
svd_full!,
407+
svd_trunc!,
408+
svd_vals!,
394409
truncate!
395410

396411
struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
397412
a::A
398413
b::B
399414
end
400415

401-
function MatrixAlgebraKit.default_eig_algorithm(a::KroneckerMatrix)
402-
return KroneckerAlgorithm(default_eig_algorithm(a.a), default_eig_algorithm(a.b))
403-
end
404-
function MatrixAlgebraKit.initialize_output(
405-
f::typeof(eig_full!), a::KroneckerMatrix, alg::KroneckerAlgorithm
406-
)
407-
return initialize_output(f, a.a, alg.a) .⊗ initialize_output(f, a.b, alg.b)
408-
end
409-
function MatrixAlgebraKit.eig_full!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
410-
eig_full!(a.a, Base.Fix2(getfield, :a).(F), alg.a)
411-
eig_full!(a.b, Base.Fix2(getfield, :b).(F), alg.b)
412-
return F
413-
end
414-
415-
function MatrixAlgebraKit.truncate!(
416-
::typeof(eig_trunc!),
417-
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
418-
strategy::TruncationStrategy,
419-
)
420-
return throw(MethodError(truncate!, (eig_trunc!, (D, V), strategy)))
421-
end
422-
423-
function MatrixAlgebraKit.initialize_output(
424-
f::typeof(eig_vals!), a::KroneckerMatrix, alg::KroneckerAlgorithm
425-
)
426-
return initialize_output(f, a.a, alg.a) initialize_output(f, a.b, alg.b)
427-
end
428-
function MatrixAlgebraKit.eig_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
429-
eig_vals!(a.a, F.a, alg.a)
430-
eig_vals!(a.b, F.b, alg.b)
431-
return F
432-
end
433-
434-
function MatrixAlgebraKit.default_eigh_algorithm(a::KroneckerMatrix)
435-
return KroneckerAlgorithm(default_eigh_algorithm(a.a), default_eigh_algorithm(a.b))
436-
end
437-
function MatrixAlgebraKit.initialize_output(
438-
f::typeof(eigh_full!), a::KroneckerMatrix, alg::KroneckerAlgorithm
439-
)
440-
return initialize_output(f, a.a, alg.a) .⊗ initialize_output(f, a.b, alg.b)
441-
end
442-
function MatrixAlgebraKit.eigh_full!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
443-
eigh_full!(a.a, Base.Fix2(getfield, :a).(F), alg.a)
444-
eigh_full!(a.b, Base.Fix2(getfield, :b).(F), alg.b)
445-
return F
416+
for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
417+
ff = Symbol("default_", f, "_algorithm")
418+
@eval begin
419+
function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...)
420+
return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...))
421+
end
422+
end
446423
end
447424

448-
function MatrixAlgebraKit.truncate!(
449-
::typeof(eigh_trunc!),
450-
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
451-
strategy::TruncationStrategy,
425+
for f in (
426+
:eig_full!,
427+
:eigh_full!,
428+
:qr_compact!,
429+
:qr_full!,
430+
:left_polar!,
431+
:lq_compact!,
432+
:lq_full!,
433+
:right_polar!,
434+
:svd_compact!,
435+
:svd_full!,
452436
)
453-
return throw(MethodError(truncate!, (eigh_trunc!, (D, V), strategy)))
437+
@eval begin
438+
function MatrixAlgebraKit.initialize_output(
439+
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
440+
)
441+
return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b)
442+
end
443+
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...)
444+
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...)
445+
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...)
446+
return F
447+
end
448+
end
454449
end
455450

456-
function MatrixAlgebraKit.initialize_output(
457-
f::typeof(eigh_vals!), a::KroneckerMatrix, alg::KroneckerAlgorithm
458-
)
459-
return initialize_output(f, a.a, alg.a) initialize_output(f, a.b, alg.b)
460-
end
461-
function MatrixAlgebraKit.eigh_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
462-
eigh_vals!(a.a, F.a, alg.a)
463-
eigh_vals!(a.b, F.b, alg.b)
464-
return F
451+
for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
452+
@eval begin
453+
function MatrixAlgebraKit.initialize_output(
454+
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
455+
)
456+
return initialize_output($f, a.a, alg.a) initialize_output($f, a.b, alg.b)
457+
end
458+
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
459+
$f(a.a, F.a, alg.a)
460+
$f(a.b, F.b, alg.b)
461+
return F
462+
end
463+
end
465464
end
466465

467-
function MatrixAlgebraKit.default_qr_algorithm(a::KroneckerMatrix; kwargs...)
468-
return KroneckerAlgorithm(
469-
default_qr_algorithm(a.a; kwargs...), default_qr_algorithm(a.b; kwargs...)
470-
)
471-
end
472-
function MatrixAlgebraKit.default_lq_algorithm(a::KroneckerMatrix; kwargs...)
473-
return KroneckerAlgorithm(
474-
default_lq_algorithm(a.a; kwargs...), default_lq_algorithm(a.b; kwargs...)
475-
)
466+
for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
467+
@eval begin
468+
function MatrixAlgebraKit.truncate!(
469+
::typeof($f),
470+
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
471+
strategy::TruncationStrategy,
472+
)
473+
return throw(MethodError(truncate!, ($f, (D, V), strategy)))
474+
end
475+
end
476476
end
477477

478-
function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::KroneckerMatrix)
479-
return initialize_output(f, a.a) initialize_output(f, a.b)
480-
end
481-
function MatrixAlgebraKit.left_null!(a::KroneckerMatrix, F; kwargs...)
482-
left_null!(a.a, F.a; kwargs...)
483-
left_null!(a.b, F.b; kwargs...)
484-
return F
478+
for f in (:left_orth!, :right_orth!)
479+
@eval begin
480+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
481+
return initialize_output($f, a.a) .⊗ initialize_output($f, a.b)
482+
end
483+
end
485484
end
486485

487-
function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::KroneckerMatrix)
488-
return initialize_output(f, a.a) initialize_output(f, a.b)
489-
end
490-
function MatrixAlgebraKit.right_null!(a::KroneckerMatrix, F; kwargs...)
491-
right_null!(a.a, F.a; kwargs...)
492-
right_null!(a.b, F.b; kwargs...)
493-
return F
486+
for f in (:left_null!, :right_null!)
487+
@eval begin
488+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
489+
return initialize_output($f, a.a) initialize_output($f, a.b)
490+
end
491+
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...)
492+
$f(a.a, F.a; kwargs...)
493+
$f(a.b, F.b; kwargs...)
494+
return F
495+
end
496+
end
494497
end
495498

496499
end

test/test_matrixalgebrakit.jl

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using KroneckerArrays:
2-
using LinearAlgebra: Hermitian, diag, norm
2+
using LinearAlgebra: Hermitian, I, diag, norm
33
using MatrixAlgebraKit:
44
eig_full,
55
eig_trunc,
@@ -48,11 +48,70 @@ using Test: @test, @test_throws, @testset
4848
d = eigh_vals(a)
4949
@test d diag(eigh_full(a)[1])
5050

51+
a = randn(elt, 2, 2) randn(elt, 3, 3)
52+
u, c = qr_compact(a)
53+
@test u * c a
54+
@test collect(u'u) I
55+
56+
a = randn(elt, 2, 2) randn(elt, 3, 3)
57+
u, c = qr_full(a)
58+
@test u * c a
59+
@test collect(u'u) I
60+
61+
a = randn(elt, 2, 2) randn(elt, 3, 3)
62+
c, u = lq_compact(a)
63+
@test c * u a
64+
@test collect(u * u') I
65+
66+
a = randn(elt, 2, 2) randn(elt, 3, 3)
67+
c, u = lq_full(a)
68+
@test c * u a
69+
@test collect(u * u') I
70+
5171
a = randn(elt, 3, 2) randn(elt, 4, 3)
5272
n = left_null(a)
5373
@test norm(n' * a) 0 atol = eps(real(elt))
5474

5575
a = randn(elt, 2, 3) randn(elt, 3, 4)
5676
n = right_null(a)
5777
@test norm(a * n') 0 atol = eps(real(elt))
78+
79+
a = randn(elt, 2, 2) randn(elt, 3, 3)
80+
u, c = left_orth(a)
81+
@test u * c a
82+
@test collect(u'u) I
83+
84+
a = randn(elt, 2, 2) randn(elt, 3, 3)
85+
c, u = right_orth(a)
86+
@test c * u a
87+
@test collect(u * u') I
88+
89+
a = randn(elt, 2, 2) randn(elt, 3, 3)
90+
u, c = left_polar(a)
91+
@test u * c a
92+
@test collect(u'u) I
93+
94+
a = randn(elt, 2, 2) randn(elt, 3, 3)
95+
c, u = right_polar(a)
96+
@test c * u a
97+
@test collect(u * u') I
98+
99+
a = randn(elt, 2, 2) randn(elt, 3, 3)
100+
u, s, v = svd_compact(a)
101+
@test u * s * v a
102+
@test collect(u'u) I
103+
@test collect(v * v') I
104+
105+
a = randn(elt, 2, 2) randn(elt, 3, 3)
106+
u, s, v = svd_full(a)
107+
@test u * s * v a
108+
@test collect(u'u) I
109+
@test collect(v * v') I
110+
111+
a = randn(elt, 2, 2) randn(elt, 3, 3)
112+
@test_throws MethodError svd_trunc(a)
113+
114+
a = randn(elt, 2, 2) randn(elt, 3, 3)
115+
s = svd_vals(a)
116+
@test s diag(svd_compact(a)[2])
58117
end

0 commit comments

Comments
 (0)