Skip to content

Commit 8f2d090

Browse files
committed
Fix type stability for fixedspacetruncation
1 parent 374b40e commit 8f2d090

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

src/algorithms/ctmrg/simultaneous.jl

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,29 +80,40 @@ function simultaneous_projectors(
8080
) where {E}
8181
coordinates = eachcoordinate(env, 1:4)
8282
T_dst = Base.promote_op(
83-
simultaneous_projectors, NTuple{3,Int}, typeof(enlarged_corners), typeof(alg)
83+
simultaneous_projectors,
84+
NTuple{3,Int},
85+
typeof(enlarged_corners),
86+
typeof(env),
87+
typeof(alg),
8488
)
8589
proj_and_info′ = similar(coordinates, T_dst)
8690
proj_and_info::typeof(proj_and_info′) =
8791
dtmap!!(proj_and_info′, coordinates) do coordinate
88-
coordinate′ = _next_coordinate(coordinate, size(env)[2:3]...)
89-
trscheme = truncation_scheme(alg, env.edges[coordinate[1], coordinate′[2:3]...])
90-
return simultaneous_projectors(
91-
coordinate, enlarged_corners, @set(alg.trscheme = trscheme)
92-
)
92+
return simultaneous_projectors(coordinate, enlarged_corners, env, alg)
9393
end
9494
return _split_proj_and_info(proj_and_info)
9595
end
9696
function simultaneous_projectors(
97-
coordinate, enlarged_corners::Array{E,3}, alg::HalfInfiniteProjector
97+
coordinate,
98+
enlarged_corners::Array{E,3},
99+
env,
100+
alg::HalfInfiniteProjector,
98101
) where {E}
99-
coordinate′ = _next_coordinate(coordinate, size(enlarged_corners)[2:3]...)
102+
coordinate′ = _next_coordinate(coordinate, size(env)[2:3]...)
103+
trscheme = truncation_scheme(alg, env.edges[coordinate[1], coordinate′[2:3]...])
104+
alg′ = @set alg.trscheme = trscheme
100105
ec = (enlarged_corners[coordinate...], enlarged_corners[coordinate′...])
101-
return compute_projector(ec, coordinate, alg)
106+
return compute_projector(ec, coordinate, alg)
102107
end
103108
function simultaneous_projectors(
104-
coordinate, enlarged_corners::Array{E,3}, alg::FullInfiniteProjector
109+
coordinate,
110+
enlarged_corners::Array{E,3},
111+
env,
112+
alg::FullInfiniteProjector,
105113
) where {E}
114+
coordinate′ = _next_coordinate(coordinate, size(env)[2:3]...)
115+
trscheme = truncation_scheme(alg, env.edges[coordinate[1], coordinate′[2:3]...])
116+
alg′ = @set alg.trscheme = trscheme
106117
rowsize, colsize = size(enlarged_corners)[2:3]
107118
coordinate2 = _next_coordinate(coordinate, rowsize, colsize)
108119
coordinate3 = _next_coordinate(coordinate2, rowsize, colsize)
@@ -113,7 +124,7 @@ function simultaneous_projectors(
113124
enlarged_corners[coordinate2...],
114125
enlarged_corners[coordinate3...],
115126
)
116-
return compute_projector(ec, coordinate, alg)
127+
return compute_projector(ec, coordinate, alg)
117128
end
118129

119130
"""

src/utility/diffable_threads.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,16 @@ function ChainRulesCore.rrule(
3939
config::RuleConfig{>:HasReverseMode},
4040
::typeof(dtmap!!),
4141
f,
42-
C::AbstractArray,
42+
C::AbstractArray,
4343
A::AbstractArray;
4444
kwargs...,
4545
)
46-
return rrule(config, dtmap(f, A; kwargs...))
46+
C, dtmap_pullback = rrule(config, dtmap, f, A; kwargs...)
47+
function dtmap!!_pullback(dy)
48+
dtmap, df, dA = dtmap_pullback(dy)
49+
return dtmap, df, NoTangent, dA
50+
end
51+
return C, dtmap!!_pullback
4752
end
4853

4954
"""

0 commit comments

Comments
 (0)