@@ -86,13 +86,17 @@ function ManifoldPoint(state::FiniteMPS, envs)
8686end
8787function ManifoldPoint(state:: InfiniteMPS , envs)
8888 δmin = sqrt(eps(real(scalartype(state))))
89- g_and_ρ = tmap(1 : length(state); scheduler= MPSKit. Defaults. scheduler[]) do i
89+ Tg = Core. Compiler. return_type(Grassmann. project,
90+ Tuple{eltype(state. AL),eltype(state. AL)})
91+ g = similar(state. AL, Tg)
92+ ρ = similar(state. C)
93+ tforeach(1 : length(state); scheduler= MPSKit. Defaults. scheduler[]) do i
9094 AC′ = MPSKit.∂∂AC(i, state, envs. operator, envs) * state. AC[i]
91- g = Grassmann. project(AC′, state. AL[i])
92- ρ = regularize(state. C[i], max(norm(g) / 10 , δmin))
93- return g, ρ
95+ g[i] = Grassmann. project(AC′, state. AL[i])
96+ ρ[i] = regularize(state. C[i], max(norm(g[i] ) / 10 , δmin))
97+ return nothing
9498 end
95- return ManifoldPoint(state, envs, first.(g_and_ρ), last.(g_and_ρ) )
99+ return ManifoldPoint(state, envs, g, ρ )
96100end
97101
98102function ManifoldPoint(state:: MultilineMPS , envs)
@@ -125,7 +129,9 @@ cell as tangent vectors on Grassmann manifolds.
125129"""
126130function fg(x:: ManifoldPoint{T} ) where {T<: Union{InfiniteMPS,FiniteMPS} }
127131 # the gradient I want to return is the preconditioned gradient!
128- g_prec = tmap(eachindex(x. g); scheduler= MPSKit. Defaults. scheduler[]) do i
132+ Tg = Core. Compiler. return_type(PrecGrad, Tuple{eltype(x. g),eltype(x. Rhoreg)})
133+ g_prec = similar(x. g, Tg)
134+ tmap!(g_prec, eachindex(x. g); scheduler= MPSKit. Defaults. scheduler[]) do i
129135 return PrecGrad(rmul!(copy(x. g[i]), x. state. C[i]' ), x. Rhoreg[i])
130136 end
131137
@@ -159,7 +165,8 @@ function retract(x::ManifoldPoint{<:MultilineMPS}, tg, alpha)
159165 g = reshape(tg, size(x. state))
160166
161167 nal = similar(x. state. AL)
162- h = tmap(eachindex(g); scheduler= MPSKit. Defaults. scheduler[]) do i
168+ h = similar(tg)
169+ tmap!(h, eachindex(g); scheduler= MPSKit. Defaults. scheduler[]) do i
163170 nal[i], th = Grassmann. retract(x. state. AL[i], g[i]. Pg, alpha)
164171 return PrecGrad(th)
165172 end
@@ -179,7 +186,7 @@ function retract(x::ManifoldPoint{<:InfiniteMPS}, g, alpha)
179186 nal = similar(state. AL)
180187 h = similar(g) # The tangent at the end-point
181188
182- h = tmap( eachindex(g); scheduler= MPSKit. Defaults. scheduler[]) do i
189+ tmap!(h, eachindex(g); scheduler= MPSKit. Defaults. scheduler[]) do i
183190 nal[i], th = Grassmann. retract(state. AL[i], g[i]. Pg, alpha)
184191 return PrecGrad(th)
185192 end
@@ -217,10 +224,11 @@ Transport a tangent vector `h` along the retraction from `x` in direction `g` by
217224`alpha`. `xp` is the end-point of the retraction.
218225"""
219226function transport!(h, x, g, alpha, xp)
220- return tmap!(h, eachindex (h); scheduler= MPSKit. Defaults. scheduler[]) do i
221- return PrecGrad(Grassmann. transport!(h[i]. Pg, x. state. AL[i], g[i]. Pg, alpha,
222- xp. state. AL[i]))
227+ tforeach( 1 : length (h); scheduler= MPSKit. Defaults. scheduler[]) do i
228+ return h[i] = PrecGrad(Grassmann. transport!(h[i]. Pg, x. state. AL[i], g[i]. Pg, alpha,
229+ xp. state. AL[i]))
223230 end
231+ return h
224232end
225233
226234"""
0 commit comments