@@ -21,6 +21,18 @@ macro check_eltype(x, y, f=:identity, g=:eltype)
2121 return esc(:($ g($ x) == $ f($ g($ y)) || throw(ArgumentError($ msg))))
2222end
2323
24+ function MatrixAlgebraKit. _select_algorithm(_, :: AbstractTensorMap ,
25+ alg:: MatrixAlgebraKit.AbstractAlgorithm )
26+ return alg
27+ end
28+ function MatrixAlgebraKit. _select_algorithm(f, t:: AbstractTensorMap , alg:: NamedTuple )
29+ return MatrixAlgebraKit. select_algorithm(f, t; alg... )
30+ end
31+
32+ function _select_truncation(f, :: AbstractTensorMap ,
33+ trunc:: MatrixAlgebraKit.TruncationStrategy )
34+ return trunc
35+ end
2436# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
2537# T = scalartype(t)
2638# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
@@ -76,7 +88,7 @@ function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTens
7688 :: MatrixAlgebraKit.AbstractAlgorithm )
7789 V_cod = fuse(codomain(t))
7890 V_dom = fuse(domain(t))
79- U = similar(t, domain (t) ← V_cod)
91+ U = similar(t, codomain (t) ← V_cod)
8092 S = similar(t, real(scalartype(t)), V_cod ← V_dom)
8193 Vᴴ = similar(t, V_dom ← domain(t))
8294 return U, S, Vᴴ
@@ -476,18 +488,19 @@ function MatrixAlgebraKit.check_input(::typeof(left_polar!), t, (W, P))
476488 @check_eltype P t
477489
478490 # space checks
479- space(W) == (codomain(t) ← fuse(domain(t)) ) ||
491+ space(W) == space(t ) ||
480492 throw(SpaceMismatch(" `left_polar!(t, (W, P))` requires `space(W) == (codomain(t) ← domain(t))`" ))
481- space(P) == (fuse( domain(t) ) ← domain(t)) ||
493+ space(P) == (domain(t) ← domain(t)) ||
482494 throw(SpaceMismatch(" `left_polar!(t, (W, P))` requires `space(P) == (domain(t) ← domain(t))`" ))
483495
484496 return nothing
485497end
486498
487499# TODO : do we really not want to fuse the spaces?
488- function MatrixAlgebraKit. initialize_output(:: typeof (left_polar!), t:: AbstractTensorMap )
489- W = similar(t, codomain(t) ← fuse(domain(t)))
490- P = similar(t, fuse(domain(t)) ← domain(t))
500+ function MatrixAlgebraKit. initialize_output(:: typeof (left_polar!), t:: AbstractTensorMap ,
501+ :: MatrixAlgebraKit.AbstractAlgorithm )
502+ W = similar(t, space(t))
503+ P = similar(t, domain(t) ← domain(t))
491504 return W, P
492505end
493506
@@ -558,40 +571,48 @@ function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), t::AbstractTe
558571 return C, Vᴴ
559572end
560573
561- function MatrixAlgebraKit. left_orth!(t:: AbstractTensorMap , VC; kwargs... )
562- MatrixAlgebraKit. check_input(left_orth!, t, VC)
563- atol = get(kwargs, :atol, 0 )
564- rtol = get(kwargs, :rtol, 0 )
565- kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd)
566-
567- if ! (iszero(atol) && iszero(rtol)) && kind != :svd
568- throw(ArgumentError(" nonzero tolerance not supported for left_orth with kind=$kind " ))
574+ function MatrixAlgebraKit. left_orth!(t:: AbstractTensorMap , VC;
575+ trunc= nothing ,
576+ kind= isnothing(trunc) ?
577+ :qr : :svd,
578+ alg_qr= (; positive= true ),
579+ alg_polar= (;),
580+ alg_svd= (;))
581+ if ! isnothing(trunc) && kind != :svd
582+ throw(ArgumentError(" truncation not supported for left_orth with kind=$kind " ))
569583 end
570584
571585 if kind == :qr
572- alg = get(kwargs, :alg, MatrixAlgebraKit. select_algorithm (qr_compact!, t) )
573- return qr_compact!(t, VC, alg )
574- elseif kind == :qrpos
575- alg = get(kwargs, :alg,
576- MatrixAlgebraKit . select_algorithm(qr_compact!, t; positive = true ))
577- return qr_compact!(t, VC, alg )
578- elseif kind == :polar
579- alg = get(kwargs, :alg, MatrixAlgebraKit . select_algorithm(left_polar!, t))
580- return left_polar!(t, VC, alg)
581- elseif kind == :svd && iszero(atol) && iszero(rtol )
582- alg = get(kwargs, :alg, MatrixAlgebraKit. select_algorithm (svd_compact!, t) )
586+ alg_qr′ = MatrixAlgebraKit. _select_algorithm (qr_compact!, t, alg_qr )
587+ return qr_compact!(t, VC, alg_qr′ )
588+ end
589+
590+ if kind == :polar
591+ alg_polar′ = MatrixAlgebraKit . _select_algorithm(left_polar!, t, alg_polar )
592+ return left_polar!(t, VC, alg_polar′)
593+ end
594+
595+ if kind == :svd && isnothing(trunc )
596+ alg_svd′ = MatrixAlgebraKit. _select_algorithm (svd_compact!, t, alg_svd )
583597 V, C = VC
584598 S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) ← codomain(C))
585- U, S, Vᴴ = svd_compact!(t, (V, S, C), alg )
599+ U, S, Vᴴ = svd_compact!(t, (V, S, C), alg_svd′ )
586600 return U, lmul!(S, Vᴴ)
587- elseif kind == :svd
588- alg_svd = MatrixAlgebraKit. select_algorithm(svd_compact!, t)
589- trunc = MatrixAlgebraKit. TruncationKeepAbove(atol, rtol)
590- alg = get(kwargs, :alg, MatrixAlgebraKit. TruncatedAlgorithm(alg_svd, trunc))
601+ end
602+
603+ if kind == :svd
604+ alg_svd′ = MatrixAlgebraKit. _select_algorithm(svd_compact!, t, alg_svd)
605+ alg_svd_trunc = MatrixAlgebraKit. select_algorithm(svd_trunc!, t; trunc,
606+ alg= alg_svd′)
591607 V, C = VC
592608 S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) ← codomain(C))
593- U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg )
609+ U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg_svd_trunc )
594610 return U, lmul!(S, Vᴴ)
611+ end
612+
613+ throw(ArgumentError(" `left_orth!` received unknown value `kind = $kind `" ))
614+ end
615+
595616 else
596617 throw(ArgumentError(" `left_orth!` received unknown value `kind = $kind `" ))
597618 end
0 commit comments