Skip to content

Commit 011d8a5

Browse files
committed
rework orths to not take allocate output first
1 parent 783256a commit 011d8a5

File tree

2 files changed

+63
-11
lines changed

2 files changed

+63
-11
lines changed

src/tensors/factorizations/factorizations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import MatrixAlgebraKit: default_algorithm,
3636
eigh_full!, eigh_trunc!, eigh_vals!,
3737
eig_full!, eig_trunc!, eig_vals!,
3838
left_polar!, left_orth_polar!, right_polar!, right_orth_polar!,
39-
left_null_svd!, right_null_svd!,
39+
left_null_svd!, right_null_svd!, left_orth_svd!, right_orth_svd!,
4040
left_orth!, right_orth!, left_null!, right_null!,
4141
truncate!, findtruncated, findtruncated_sorted,
4242
diagview, isisometry

src/tensors/factorizations/matrixalgebrakit.jl

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -553,16 +553,6 @@ function initialize_output(::typeof(right_polar!), t::AbstractTensorMap,
553553
return P, Wᴴ
554554
end
555555

556-
# Needed to get algorithm selection to behave
557-
function left_orth_polar!(t::AbstractTensorMap, VC, alg)
558-
alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg)
559-
return left_orth_polar!(t, VC, alg′)
560-
end
561-
function right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg)
562-
alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg)
563-
return right_orth_polar!(t, CVᴴ, alg′)
564-
end
565-
566556
# Orthogonalization
567557
# -----------------
568558
function check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, ::AbstractAlgorithm)
@@ -609,6 +599,68 @@ function initialize_output(::typeof(right_orth!), t::AbstractTensorMap)
609599
return C, Vᴴ
610600
end
611601

602+
# This is a rework of the dispatch logic in order to avoid having to deal with having to
603+
# allocate the output before knowing the kind of decomposition. In particular, here I disable
604+
# providing output arguments for left_ and right_orth.
605+
# This is mainly because polar decompositions have different shapes, and SVD for Diagonal
606+
# also does
607+
function left_orth!(t::AbstractTensorMap;
608+
trunc::TruncationStrategy=notrunc(),
609+
kind=trunc == notrunc() ? :qr : :svd,
610+
alg_qr=(; positive=true), alg_polar=(;), alg_svd=(;))
611+
trunc == notrunc() || kind === :svd ||
612+
throw(ArgumentError("truncation not supported for left_orth with kind = $kind"))
613+
614+
kind === :qr && return qr_compact!(t; alg_qr...)
615+
kind === :polar && return left_orth_polar!(t; alg_polar...)
616+
kind === :svd && return left_orth_svd!(t; trunc, alg_svd...)
617+
618+
throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`"))
619+
end
620+
function right_orth!(t::AbstractTensorMap;
621+
trunc::TruncationStrategy=notrunc(),
622+
kind=trunc == notrunc() ? :lq : :svd,
623+
alg_lq=(; positive=true), alg_polar=(;), alg_svd=(;))
624+
trunc == notrunc() || kind === :svd ||
625+
throw(ArgumentError("truncation not supported for right_orth with kind = $kind"))
626+
627+
kind === :qr && return lq_compact!(t; alg_lq...)
628+
kind === :polar && return right_orth_polar!(t; alg_polar...)
629+
kind === :svd && return right_orth_svd!(t; trunc, alg_svd...)
630+
631+
throw(ArgumentError(lazy"`right_orth!` received unknown value `kind = $kind`"))
632+
end
633+
634+
function left_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...)
635+
alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg; kwargs...)
636+
VC = initialize_output(left_orth!, t)
637+
return left_orth_polar!(t, VC, alg′)
638+
end
639+
function left_orth_polar!(t::AbstractTensorMap, VC, alg)
640+
alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg)
641+
return left_orth_polar!(t, VC, alg′)
642+
end
643+
function right_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...)
644+
alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg; kwargs...)
645+
CVᴴ = initialize_output(right_orth!, t)
646+
return right_orth_polar!(t, CVᴴ, alg′)
647+
end
648+
function right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg)
649+
alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg)
650+
return right_orth_polar!(t, CVᴴ, alg′)
651+
end
652+
653+
function left_orth_svd!(t::AbstractTensorMap; trunc=notrunc(), kwargs...)
654+
U, S, Vᴴ = trunc == notrunc() ? svd_compact!(t; kwargs...) :
655+
svd_trunc!(t; trunc, kwargs...)
656+
return U, lmul!(S, Vᴴ)
657+
end
658+
function right_orth_svd!(t::AbstractTensorMap; trunc=notrunc(), kwargs...)
659+
U, S, Vᴴ = trunc == notrunc() ? svd_compact!(t; kwargs...) :
660+
svd_trunc!(t; trunc, kwargs...)
661+
return rmul!(U, S), Vᴴ
662+
end
663+
612664
# Nullspace
613665
# ---------
614666
function check_input(::typeof(left_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm)

0 commit comments

Comments
 (0)