Skip to content

Commit 70d1963

Browse files
committed
replace tmap with tmap!
1 parent 62bbefa commit 70d1963

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

src/algorithms/grassmann.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,17 @@ function ManifoldPoint(state::FiniteMPS, envs)
8686
end
8787
function ManifoldPoint(state::InfiniteMPS, envs)
8888
δmin = sqrt(eps(real(scalartype(state))))
89-
g_and_ρ = tmap(1:length(state); scheduler=MPSKit.Defaults.scheduler[]) do i
89+
Tg = Core.Compiler.return_type(Grassmann.project,
90+
Tuple{eltype(state.AL),eltype(state.AL)})
91+
g = similar(state.AL, Tg)
92+
ρ = similar(state.C)
93+
tforeach(1:length(state); scheduler=MPSKit.Defaults.scheduler[]) do i
9094
AC′ = MPSKit.∂∂AC(i, state, envs.operator, envs) * state.AC[i]
91-
g = Grassmann.project(AC′, state.AL[i])
92-
ρ = regularize(state.C[i], max(norm(g) / 10, δmin))
93-
return g, ρ
95+
g[i] = Grassmann.project(AC′, state.AL[i])
96+
ρ[i] = regularize(state.C[i], max(norm(g[i]) / 10, δmin))
97+
return nothing
9498
end
95-
return ManifoldPoint(state, envs, first.(g_and_ρ), last.(g_and_ρ))
99+
return ManifoldPoint(state, envs, g, ρ)
96100
end
97101

98102
function ManifoldPoint(state::MultilineMPS, envs)
@@ -125,7 +129,9 @@ cell as tangent vectors on Grassmann manifolds.
125129
"""
126130
function fg(x::ManifoldPoint{T}) where {T<:Union{InfiniteMPS,FiniteMPS}}
127131
# the gradient I want to return is the preconditioned gradient!
128-
g_prec = tmap(eachindex(x.g); scheduler=MPSKit.Defaults.scheduler[]) do i
132+
Tg = Core.Compiler.return_type(PrecGrad, Tuple{eltype(x.g),eltype(x.Rhoreg)})
133+
g_prec = similar(x.g, Tg)
134+
tmap!(g_prec, eachindex(x.g); scheduler=MPSKit.Defaults.scheduler[]) do i
129135
return PrecGrad(rmul!(copy(x.g[i]), x.state.C[i]'), x.Rhoreg[i])
130136
end
131137

@@ -159,7 +165,8 @@ function retract(x::ManifoldPoint{<:MultilineMPS}, tg, alpha)
159165
g = reshape(tg, size(x.state))
160166

161167
nal = similar(x.state.AL)
162-
h = tmap(eachindex(g); scheduler=MPSKit.Defaults.scheduler[]) do i
168+
h = similar(tg)
169+
tmap!(h, eachindex(g); scheduler=MPSKit.Defaults.scheduler[]) do i
163170
nal[i], th = Grassmann.retract(x.state.AL[i], g[i].Pg, alpha)
164171
return PrecGrad(th)
165172
end
@@ -179,7 +186,7 @@ function retract(x::ManifoldPoint{<:InfiniteMPS}, g, alpha)
179186
nal = similar(state.AL)
180187
h = similar(g) # The tangent at the end-point
181188

182-
h = tmap(eachindex(g); scheduler=MPSKit.Defaults.scheduler[]) do i
189+
tmap!(h, eachindex(g); scheduler=MPSKit.Defaults.scheduler[]) do i
183190
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
184191
return PrecGrad(th)
185192
end
@@ -217,10 +224,11 @@ Transport a tangent vector `h` along the retraction from `x` in direction `g` by
217224
`alpha`. `xp` is the end-point of the retraction.
218225
"""
219226
function transport!(h, x, g, alpha, xp)
220-
return tmap!(h, eachindex(h); scheduler=MPSKit.Defaults.scheduler[]) do i
221-
return PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
222-
xp.state.AL[i]))
227+
tforeach(1:length(h); scheduler=MPSKit.Defaults.scheduler[]) do i
228+
return h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
229+
xp.state.AL[i]))
223230
end
231+
return h
224232
end
225233

226234
"""

src/algorithms/timestep/tdvp.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,22 @@ function timestep(ψ::InfiniteMPS, H, t::Number, dt::Number, alg::TDVP,
2626

2727
scheduler = Defaults.scheduler[]
2828
if scheduler isa SerialScheduler
29-
temp_ACs = tmap(1:length(ψ); scheduler) do loc
29+
temp_ACs = tmap!(temp_ACs, 1:length(ψ); scheduler) do loc
3030
return integrate(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], t, dt, alg.integrator)
3131
end
32-
temp_Cs = tmap(1:length(ψ); scheduler) do loc
32+
temp_Cs = tmap!(temp_Cs, 1:length(ψ); scheduler) do loc
3333
return integrate(∂∂C(loc, ψ, H, envs), ψ.C[loc], t, dt, alg.integrator)
3434
end
3535
else
3636
@sync begin
3737
Threads.@spawn begin
38-
temp_ACs = tmap(1:length(ψ); scheduler) do loc
38+
temp_ACs = tmap!(temp_ACs, 1:length(ψ); scheduler) do loc
3939
return integrate(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], t, dt,
4040
alg.integrator)
4141
end
4242
end
4343
Threads.@spawn begin
44-
temp_Cs = tmap(1:length(ψ); scheduler) do loc
44+
temp_Cs = tmap!(temp_Cs, 1:length(ψ); scheduler) do loc
4545
return integrate(∂∂C(loc, ψ, H, envs), ψ.C[loc], t, dt, alg.integrator)
4646
end
4747
end

0 commit comments

Comments
 (0)