Skip to content

Commit e1b953a

Browse files
authored
Support different ortho directions in factorize (#18)
1 parent e37a397 commit e1b953a

File tree

3 files changed

+38
-21
lines changed

3 files changed

+38
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorBase"
22
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/ITensorBase.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ using UnspecifiedTypes: UnspecifiedZero
165165
function specify_eltype(a::Zeros{UnspecifiedZero}, elt::Type)
166166
return Zeros{elt}(axes(a))
167167
end
168+
function specify_eltype(a::AbstractArray, elt::Type)
169+
return a
170+
end
168171

169172
# TODO: Use `adapt` to reach down into the storage.
170173
function specify_eltype!(a::AbstractITensor, elt::Type)

src/quirks.jl

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,34 +36,48 @@ function onehot(iv::Pair{<:Index,<:Int})
3636
return a
3737
end
3838

39+
# TODO: This is just a stand-in for truncated SVD
40+
# that only makes use of `maxdim`, just to get some
41+
# functionality running in `ITensorMPS.jl`.
42+
# Define a proper truncated SVD in
43+
# `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
44+
function svd_truncated(a::AbstractITensor, codomain_inds; maxdim)
45+
U, S, V = svd(a, codomain_inds)
46+
r = Base.OneTo(min(maxdim, minimum(Int.(size(S)))))
47+
u = commonind(U, S)
48+
v = commonind(V, S)
49+
us = uniqueinds(U, S)
50+
vs = uniqueinds(V, S)
51+
U′ = U[(us .=> :)..., u => r]
52+
S′ = S[u => r, v => r]
53+
V′ = V[v => r, (vs .=> :)...]
54+
return U′, S′, V′
55+
end
56+
3957
using LinearAlgebra: qr, svd
4058
# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
4159
function factorize(
42-
a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, kwargs...
60+
a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, ortho="left", kwargs...
4361
)
4462
# TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`?
4563
# See https://github.com/ITensor/NamedDimsArrays.jl/issues/22.
46-
codomain_inds′ = intersect(inds(a), codomain_inds)
47-
if isnothing(maxdim) && isnothing(cutoff)
48-
Q, R = qr(a, codomain_inds′)
49-
return Q, R, (; truncerr=zero(Bool),)
64+
codomain_inds′ = if ortho == "left"
65+
intersect(inds(a), codomain_inds)
66+
elseif ortho == "right"
67+
setdiff(inds(a), codomain_inds)
5068
else
51-
U, S, V = svd(a, codomain_inds′)
52-
# TODO: This is just a stand-in for truncated SVD
53-
# that only makes use of `maxdim`, just to get some
54-
# functionality running in `ITensorMPS.jl`.
55-
# Define a proper truncated SVD in
56-
# `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
57-
r = Base.OneTo(min(maxdim, minimum(Int.(size(S)))))
58-
u = commonind(U, S)
59-
v = commonind(V, S)
60-
us = uniqueinds(U, S)
61-
vs = uniqueinds(V, S)
62-
U′ = U[(us .=> :)..., u => r]
63-
S′ = S[u => r, v => r]
64-
V′ = V[v => r, (vs .=> :)...]
65-
return U′, S′ * V′, (; truncerr=zero(Bool),)
69+
error("Bad `ortho` input.")
70+
end
71+
F1, F2 = if isnothing(maxdim) && isnothing(cutoff)
72+
qr(a, codomain_inds′)
73+
else
74+
U, S, V = svd_truncated(a, codomain_inds′; maxdim)
75+
U, S * V
76+
end
77+
if ortho == "right"
78+
F2, F1 = F1, F2
6679
end
80+
return F1, F2, (; truncerr=zero(Bool),)
6781
end
6882

6983
# TODO: Used in `ITensorMPS.jl`, decide where or if to define it.

0 commit comments

Comments
 (0)