Skip to content

Commit 883a4d7

Browse files
authored
Rudimentary truncation in factorize (#17)
1 parent 07a49d1 commit 883a4d7

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorBase"
22
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/quirks.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,31 @@ end
3838

3939
using LinearAlgebra: qr, svd
4040
# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
41-
function factorize(a::AbstractITensor, args...; maxdim=nothing, cutoff=nothing, kwargs...)
41+
function factorize(
42+
a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, kwargs...
43+
)
44+
# TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`?
45+
# See https://github.com/ITensor/NamedDimsArrays.jl/issues/22.
46+
codomain_inds′ = intersect(inds(a), codomain_inds)
4247
if isnothing(maxdim) && isnothing(cutoff)
43-
Q, R = qr(a, args...)
44-
return Q, R
48+
Q, R = qr(a, codomain_inds′)
49+
return Q, R, (; truncerr=zero(Bool),)
4550
else
46-
error("Truncation in `factorize` not implemented yet.")
47-
U, S, V = svd(a, args...; kwargs...)
48-
return U, S * V
51+
U, S, V = svd(a, codomain_inds′)
52+
# TODO: This is just a stand-in for truncated SVD
53+
# that only makes use of `maxdim`, just to get some
54+
# functionality running in `ITensorMPS.jl`.
55+
# Define a proper truncated SVD in
56+
# `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
57+
r = Base.OneTo(min(maxdim, minimum(Int.(size(S)))))
58+
u = commonind(U, S)
59+
v = commonind(V, S)
60+
us = uniqueinds(U, S)
61+
vs = uniqueinds(V, S)
62+
U′ = U[(us .=> :)..., u => r]
63+
S′ = S[u => r, v => r]
64+
V′ = V[v => r, (vs .=> :)...]
65+
return U′, S′ * V′, (; truncerr=zero(Bool),)
4966
end
5067
end
5168

0 commit comments

Comments
 (0)