@@ -20,46 +20,37 @@ hasqns(a::AbstractITensor) = all(hasqns, inds(a))
2020# TODO : Investigate this and see if we can get rid of it.
2121Base. Broadcast. extrude (a:: AbstractITensor ) = a
2222
23- # TODO : This is just a stand-in for truncated SVD
24- # that only makes use of `maxdim`, just to get some
25- # functionality running in `ITensorMPS.jl`.
26- # Define a proper truncated SVD in
27- # `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
28- function svd_truncated (a:: AbstractITensor , codomain_inds; maxdim)
29- U, S, V = svd (a, codomain_inds)
30- r = Base. OneTo (min (maxdim, minimum (Int .(size (S)))))
31- u = commonind (U, S)
32- v = commonind (V, S)
33- us = uniqueinds (U, S)
34- vs = uniqueinds (V, S)
35- U′ = U[(us .=> :). .. , u => r]
36- S′ = S[u => r, v => r]
37- V′ = V[v => r, (vs .=> :). .. ]
38- return U′, S′, V′
39- end
23+ # See: https://github.com/JuliaLang/julia/blob/v1.11.4/base/namedtuple.jl#L269
24+ # `filter(f, ::NamedTuple)` is available in Julia v1.11, delete once
25+ # we drop support for Julia v1.10.
26+ filter_namedtuple (f, xs:: NamedTuple ) = xs[filter (k -> f (xs[k]), keys (xs))]
4027
41- using LinearAlgebra: qr, svd
42- # TODO : Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
43- function factorize (
44- a:: AbstractITensor , codomain_inds; maxdim= nothing , cutoff= nothing , ortho= " left" , kwargs...
28+ function translate_factorize_kwargs (;
29+ # MatrixAlgebraKit.jl/TensorAlgebra.jl kwargs.
30+ orth= nothing ,
31+ rtol= nothing ,
32+ maxrank= nothing ,
33+ # ITensors.jl kwargs.
34+ ortho= nothing ,
35+ cutoff= nothing ,
36+ maxdim= nothing ,
37+ kwargs... ,
4538)
46- # TODO : Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`?
47- # See https://github.com/ITensor/NamedDimsArrays.jl/issues/22.
48- codomain_inds′ = if ortho == " left"
49- intersect (inds (a), codomain_inds)
50- elseif ortho == " right"
51- setdiff (inds (a), codomain_inds)
52- else
53- error (" Bad `ortho` input." )
54- end
55- F1, F2 = if isnothing (maxdim) && isnothing (cutoff)
56- qr (a, codomain_inds′)
57- else
58- U, S, V = svd_truncated (a, codomain_inds′; maxdim)
59- U, S * V
60- end
61- if ortho == " right"
62- F2, F1 = F1, F2
63- end
64- return F1, F2, (; truncerr= zero (Bool),)
39+ orth:: Symbol = @something orth ortho :left
40+ rtol = @something rtol cutoff Some (nothing )
41+ maxrank = @something maxrank maxdim Some (nothing )
42+ ! isnothing (maxrank) && error (" `maxrank` not supported yet." )
43+ return filter_namedtuple (! isnothing, (; orth, rtol, maxrank, kwargs... ))
44+ end
45+
46+ using TensorAlgebra: TensorAlgebra, factorize
47+ function TensorAlgebra. factorize (a:: AbstractITensor , codomain_inds, domain_inds; kwargs... )
48+ return invoke (
49+ factorize,
50+ Tuple{AbstractNamedDimsArray,Any,Any},
51+ a,
52+ codomain_inds,
53+ domain_inds;
54+ translate_factorize_kwargs (; kwargs... )... ,
55+ )
6556end
0 commit comments