Skip to content

Commit 5a85d32

Browse files
authored
[Fix] correctly handle eachsite for Multiline objects (#281)
* correctly handle `eachsite` for `Multiline` objects * Bump v0.13.1
1 parent 46b7c5f commit 5a85d32

File tree

6 files changed

+15
-9
lines changed

6 files changed

+15
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MPSKit"
22
uuid = "bb1c41ca-d63c-52ed-829e-0820dda26502"
33
authors = "Lukas Devos, Maarten Van Damme and contributors"
4-
version = "0.13.0"
4+
version = "0.13.1"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/algorithms/approximate/vomps.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ end
5858
function localupdate_step!(::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple},
5959
::SerialScheduler)
6060
alg_orth = QRpos()
61-
eachsite = 1:length(state.mps)
61+
6262
ACs = similar(state.mps.AC)
6363
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs
6464

65-
foreach(eachsite) do site
65+
foreach(eachsite(state.mps)) do site
6666
AC = circshift([AC_projection(CartesianIndex(row, site), state.mps, state.operator,
6767
state.envs)
6868
for row in 1:size(state.mps, 1)], 1)
@@ -78,12 +78,11 @@ end
7878
function localupdate_step!(::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple},
7979
scheduler)
8080
alg_orth = QRpos()
81-
eachsite = 1:length(state.mps)
8281

8382
ACs = similar(state.mps.AC)
8483
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs
8584

86-
tforeach(eachsite; scheduler) do site
85+
tforeach(eachsite(state.mps); scheduler) do site
8786
local AC, C
8887
@sync begin
8988
Threads.@spawn begin

src/algorithms/groundstate/vumps.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,12 @@ function localupdate_step!(it::IterativeSolver{<:VUMPS}, state,
104104
alg_orth = QRpos()
105105

106106
mps = state.mps
107-
eachsite = 1:length(mps)
108107
src_Cs = mps isa Multiline ? eachcol(mps.C) : mps.C
109108
src_ACs = mps isa Multiline ? eachcol(mps.AC) : mps.AC
110109
ACs = similar(mps.AC)
111110
dst_ACs = mps isa Multiline ? eachcol(ACs) : ACs
112111

113-
tforeach(eachsite, src_ACs, src_Cs; scheduler) do site, AC₀, C₀
112+
tforeach(eachsite(mps), src_ACs, src_Cs; scheduler) do site, AC₀, C₀
114113
dst_ACs[site] = _localupdate_vumps_step!(site, mps, state.operator, state.envs,
115114
AC₀, C₀; parallel=false, alg_orth,
116115
state.which, alg_eigsolve)

src/algorithms/statmech/vomps.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,12 @@ function localupdate_step!(::IterativeSolver{<:VOMPS}, state,
100100
scheduler=Defaults.scheduler[])
101101
alg_orth = QRpos()
102102
mps = state.mps
103-
eachsite = 1:length(mps)
104103
src_Cs = mps isa Multiline ? eachcol(mps.C) : mps.C
105104
src_ACs = mps isa Multiline ? eachcol(mps.AC) : mps.AC
106105
ACs = similar(mps.AC)
107106
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs
108107

109-
tforeach(eachsite, src_ACs, src_Cs; scheduler) do site, AC₀, C₀
108+
tforeach(eachsite(mps), src_ACs, src_Cs; scheduler) do site, AC₀, C₀
110109
dst_ACs[site] = _localupdate_vomps_step!(site, mps, state.operator, state.envs,
111110
AC₀, C₀; alg_orth, parallel=false)
112111
return nothing

src/states/abstractmps.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,10 @@ physicalspace(A::MPSTensor) = space(A, 2)
197197
physicalspace(A::GenericMPSTensor) = prod(x -> space(A, x), 2:(numind(A) - 1))
198198
physicalspace(O::MPOTensor) = space(O, 2)
199199
physicalspace(O::AbstractBlockTensorMap{<:Any,<:Any,2,2}) = only(space(O, 2))
200+
201+
"""
202+
eachsite(state::AbstractMPS)
203+
204+
Return an iterator over the sites of the MPS `state`.
205+
"""
206+
eachsite::AbstractMPS) = eachindex(ψ)

src/utility/multiline.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ function Base.axes(m::Multiline, i::Int)
3030
end
3131
Base.eachindex(m::Multiline) = CartesianIndices(size(m))
3232

33+
eachsite(m::Multiline) = eachsite(first(parent(m)))
34+
3335
Base.getindex(m::Multiline, i::Int) = getindex(parent(m), i)
3436
Base.setindex!(m::Multiline, v, i::Int) = (setindex!(parent(m), v, i); m)
3537

0 commit comments

Comments
 (0)