@@ -36,34 +36,48 @@ function onehot(iv::Pair{<:Index,<:Int})
3636 return a
3737end
3838
39+ # TODO : This is just a stand-in for truncated SVD
40+ # that only makes use of `maxdim`, just to get some
41+ # functionality running in `ITensorMPS.jl`.
42+ # Define a proper truncated SVD in
43+ # `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
44+ function svd_truncated (a:: AbstractITensor , codomain_inds; maxdim)
45+ U, S, V = svd (a, codomain_inds)
46+ r = Base. OneTo (min (maxdim, minimum (Int .(size (S)))))
47+ u = commonind (U, S)
48+ v = commonind (V, S)
49+ us = uniqueinds (U, S)
50+ vs = uniqueinds (V, S)
51+ U′ = U[(us .=> :). .. , u => r]
52+ S′ = S[u => r, v => r]
53+ V′ = V[v => r, (vs .=> :). .. ]
54+ return U′, S′, V′
55+ end
56+
3957using LinearAlgebra: qr, svd
4058# TODO : Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
4159function factorize (
42- a:: AbstractITensor , codomain_inds; maxdim= nothing , cutoff= nothing , kwargs...
60+ a:: AbstractITensor , codomain_inds; maxdim= nothing , cutoff= nothing , ortho = " left " , kwargs...
4361)
4462 # TODO : Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`?
4563 # See https://github.com/ITensor/NamedDimsArrays.jl/issues/22.
46- codomain_inds′ = intersect ( inds (a), codomain_inds)
47- if isnothing (maxdim) && isnothing (cutoff )
48- Q, R = qr (a, codomain_inds′)
49- return Q, R, (; truncerr = zero (Bool), )
64+ codomain_inds′ = if ortho == " left "
65+ intersect ( inds (a), codomain_inds )
66+ elseif ortho == " right "
67+ setdiff ( inds (a), codomain_inds )
5068 else
51- U, S, V = svd (a, codomain_inds′)
52- # TODO : This is just a stand-in for truncated SVD
53- # that only makes use of `maxdim`, just to get some
54- # functionality running in `ITensorMPS.jl`.
55- # Define a proper truncated SVD in
56- # `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
57- r = Base. OneTo (min (maxdim, minimum (Int .(size (S)))))
58- u = commonind (U, S)
59- v = commonind (V, S)
60- us = uniqueinds (U, S)
61- vs = uniqueinds (V, S)
62- U′ = U[(us .=> :). .. , u => r]
63- S′ = S[u => r, v => r]
64- V′ = V[v => r, (vs .=> :). .. ]
65- return U′, S′ * V′, (; truncerr= zero (Bool),)
69+ error (" Bad `ortho` input." )
70+ end
71+ F1, F2 = if isnothing (maxdim) && isnothing (cutoff)
72+ qr (a, codomain_inds′)
73+ else
74+ U, S, V = svd_truncated (a, codomain_inds′; maxdim)
75+ U, S * V
76+ end
77+ if ortho == " right"
78+ F2, F1 = F1, F2
6679 end
80+ return F1, F2, (; truncerr= zero (Bool),)
6781end
6882
6983# TODO : Used in `ITensorMPS.jl`, decide where or if to define it.
0 commit comments