Skip to content

Commit 4bac6c8

Browse files
committed
Rework left_orth
1 parent e7d8788 commit 4bac6c8

File tree

2 files changed

+96
-72
lines changed

2 files changed

+96
-72
lines changed

src/tensors/factorizations.jl

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -326,51 +326,54 @@ end
326326
function leftorth!(t::TensorMap{<:RealOrComplexFloat};
327327
alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(),
328328
atol::Real=zero(float(real(scalartype(t)))),
329-
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
330-
eps(real(float(one(scalartype(t))))) * iszero(atol))
329+
rtol::Real=(alg (SVD(), SDD())) ?
330+
zero(float(real(scalartype(t)))) :
331+
eps(real(float(one(scalartype(t))))) *
332+
iszero(atol))
331333
InnerProductStyle(t) === EuclideanInnerProduct() ||
332334
throw_invalid_innerproduct(:leftorth!)
333-
334-
VC = MatrixAlgebraKit.initialize_output(left_orth!, t)
335-
336-
if alg isa QR
337-
return left_orth!(t, VC; kind=:qr, atol, rtol)
338-
elseif alg isa QRpos
339-
return left_orth!(t, VC; kind=:qrpos, atol, rtol)
340-
elseif alg isa SDD
341-
return left_orth!(t, VC; kind=:svd, atol, rtol)
342-
elseif alg isa Polar
343-
return left_orth!(t, VC; kind=:polar, atol, rtol)
344-
elseif alg isa SVD
345-
kind = :svd
346-
if iszero(atol) && iszero(rtol)
347-
alg′ = LAPACK_QRIteration()
348-
return left_orth!(t, VC; kind,
349-
alg=BlockAlgorithm(alg′, default_blockscheduler(t)),
350-
atol, rtol)
351-
else
352-
trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol)
353-
svd_alg = LAPACK_QRIteration()
354-
scheduler = default_blockscheduler(t)
355-
alg′ = MatrixAlgebraKit.TruncatedAlgorithm(BlockAlgorithm(svd_alg, scheduler),
356-
trunc)
357-
return left_orth!(t, VC; kind, alg=alg′, atol, rtol)
358-
end
359-
elseif alg isa QL
360-
_reverse!(t; dims=2)
361-
Q, R = left_orth!(t, VC; kind=:qr, atol, rtol)
362-
_reverse!(Q; dims=2)
363-
_reverse!(R)
364-
return Q, R
365-
elseif alg isa QLpos
366-
_reverse!(t; dims=2)
367-
Q, R = left_orth!(t, VC; kind=:qrpos, atol, rtol)
368-
_reverse!(Q; dims=2)
369-
_reverse!(R)
370-
return Q, R
335+
if alg == SVD() || alg == SDD()
336+
return _leftorth!(t, alg; atol, rtol)
337+
else
338+
(iszero(atol) && iszero(rtol)) ||
339+
throw(ArgumentError("`leftorth!` with nonzero atol or rtol requires SVD or SDD algorithm"))
340+
return _leftorth!(t, alg)
371341
end
342+
end
372343

373-
throw(ArgumentError("Algorithm $alg not implemented for leftorth!"))
344+
# this promotes the algorithm to a positional argument for type stability reasons
345+
# since polar has different number of output legs
346+
# TODO: this seems like duplication from MatrixAlgebraKit.left_orth!, but that function
347+
# only has its logic with the output already specified, which breaks for polar
348+
function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD}; atol::Real,
349+
rtol::Real)
350+
alg′ = alg == SVD() ? MatrixAlgebraKit.LAPACK_QRIteration() :
351+
MatrixAlgebraKit.LAPACK_DivideAndConquer()
352+
alg_svd = BlockAlgorithm(alg′, default_blockscheduler(t))
353+
if iszero(atol) && iszero(rtol)
354+
U, S, Vᴴ = svd_compact!(t, alg_svd)
355+
return U, lmul!(S, Vᴴ)
356+
else
357+
trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol)
358+
alg_svd = MatrixAlgebraKit.select_algorithm(svd_trunc!, t; trunc,
359+
alg=alg_svd)
360+
361+
U, S, Vᴴ = svd_trunc!(t, alg_svd)
362+
return U, lmul!(S, Vᴴ)
363+
end
364+
end
365+
function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{QR,QRpos})
366+
return qr_compact!(t; positive=alg == QRpos())
367+
end
368+
function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{QL,QLpos})
369+
_reverse!(t; dims=2)
370+
Q, R = qr_compact!(t; positive=alg == QLpos())
371+
_reverse!(Q; dims=2)
372+
_reverse!(R)
373+
return Q, R
374+
end
375+
function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, ::Polar)
376+
return MatrixAlgebraKit.left_polar!(t)
374377
end
375378

376379
function leftnull!(t::TensorMap{<:RealOrComplexFloat};

src/tensors/matrixalgebrakit.jl

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ macro check_eltype(x, y, f=:identity, g=:eltype)
2121
return esc(:($g($x) == $f($g($y)) || throw(ArgumentError($msg))))
2222
end
2323

24+
function MatrixAlgebraKit._select_algorithm(_, ::AbstractTensorMap,
25+
alg::MatrixAlgebraKit.AbstractAlgorithm)
26+
return alg
27+
end
28+
function MatrixAlgebraKit._select_algorithm(f, t::AbstractTensorMap, alg::NamedTuple)
29+
return MatrixAlgebraKit.select_algorithm(f, t; alg...)
30+
end
31+
32+
function _select_truncation(f, ::AbstractTensorMap,
33+
trunc::MatrixAlgebraKit.TruncationStrategy)
34+
return trunc
35+
end
2436
# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
2537
# T = scalartype(t)
2638
# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
@@ -76,7 +88,7 @@ function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTens
7688
::MatrixAlgebraKit.AbstractAlgorithm)
7789
V_cod = fuse(codomain(t))
7890
V_dom = fuse(domain(t))
79-
U = similar(t, domain(t) V_cod)
91+
U = similar(t, codomain(t) V_cod)
8092
S = similar(t, real(scalartype(t)), V_cod V_dom)
8193
Vᴴ = similar(t, V_dom domain(t))
8294
return U, S, Vᴴ
@@ -476,18 +488,19 @@ function MatrixAlgebraKit.check_input(::typeof(left_polar!), t, (W, P))
476488
@check_eltype P t
477489

478490
# space checks
479-
space(W) == (codomain(t) fuse(domain(t))) ||
491+
space(W) == space(t) ||
480492
throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(W) == (codomain(t) ← domain(t))`"))
481-
space(P) == (fuse(domain(t)) domain(t)) ||
493+
space(P) == (domain(t) domain(t)) ||
482494
throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(P) == (domain(t) ← domain(t))`"))
483495

484496
return nothing
485497
end
486498

487499
# TODO: do we really not want to fuse the spaces?
488-
function MatrixAlgebraKit.initialize_output(::typeof(left_polar!), t::AbstractTensorMap)
489-
W = similar(t, codomain(t) fuse(domain(t)))
490-
P = similar(t, fuse(domain(t)) domain(t))
500+
function MatrixAlgebraKit.initialize_output(::typeof(left_polar!), t::AbstractTensorMap,
501+
::MatrixAlgebraKit.AbstractAlgorithm)
502+
W = similar(t, space(t))
503+
P = similar(t, domain(t) domain(t))
491504
return W, P
492505
end
493506

@@ -558,40 +571,48 @@ function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), t::AbstractTe
558571
return C, Vᴴ
559572
end
560573

561-
function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC; kwargs...)
562-
MatrixAlgebraKit.check_input(left_orth!, t, VC)
563-
atol = get(kwargs, :atol, 0)
564-
rtol = get(kwargs, :rtol, 0)
565-
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd)
566-
567-
if !(iszero(atol) && iszero(rtol)) && kind != :svd
568-
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
574+
function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC;
575+
trunc=nothing,
576+
kind=isnothing(trunc) ?
577+
:qr : :svd,
578+
alg_qr=(; positive=true),
579+
alg_polar=(;),
580+
alg_svd=(;))
581+
if !isnothing(trunc) && kind != :svd
582+
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
569583
end
570584

571585
if kind == :qr
572-
alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(qr_compact!, t))
573-
return qr_compact!(t, VC, alg)
574-
elseif kind == :qrpos
575-
alg = get(kwargs, :alg,
576-
MatrixAlgebraKit.select_algorithm(qr_compact!, t; positive=true))
577-
return qr_compact!(t, VC, alg)
578-
elseif kind == :polar
579-
alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(left_polar!, t))
580-
return left_polar!(t, VC, alg)
581-
elseif kind == :svd && iszero(atol) && iszero(rtol)
582-
alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(svd_compact!, t))
586+
alg_qr′ = MatrixAlgebraKit._select_algorithm(qr_compact!, t, alg_qr)
587+
return qr_compact!(t, VC, alg_qr′)
588+
end
589+
590+
if kind == :polar
591+
alg_polar′ = MatrixAlgebraKit._select_algorithm(left_polar!, t, alg_polar)
592+
return left_polar!(t, VC, alg_polar′)
593+
end
594+
595+
if kind == :svd && isnothing(trunc)
596+
alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_compact!, t, alg_svd)
583597
V, C = VC
584598
S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) codomain(C))
585-
U, S, Vᴴ = svd_compact!(t, (V, S, C), alg)
599+
U, S, Vᴴ = svd_compact!(t, (V, S, C), alg_svd′)
586600
return U, lmul!(S, Vᴴ)
587-
elseif kind == :svd
588-
alg_svd = MatrixAlgebraKit.select_algorithm(svd_compact!, t)
589-
trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol)
590-
alg = get(kwargs, :alg, MatrixAlgebraKit.TruncatedAlgorithm(alg_svd, trunc))
601+
end
602+
603+
if kind == :svd
604+
alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_compact!, t, alg_svd)
605+
alg_svd_trunc = MatrixAlgebraKit.select_algorithm(svd_trunc!, t; trunc,
606+
alg=alg_svd′)
591607
V, C = VC
592608
S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) codomain(C))
593-
U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg)
609+
U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg_svd_trunc)
594610
return U, lmul!(S, Vᴴ)
611+
end
612+
613+
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
614+
end
615+
595616
else
596617
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
597618
end

0 commit comments

Comments
 (0)