
Commit eca22fd

Add buffer allocator
1 parent 5d5af38 commit eca22fd

4 files changed: +124 -49 lines changed

src/MPSKit.jl

Lines changed: 1 addition & 0 deletions

@@ -99,6 +99,7 @@ include("utility/logging.jl")
 using .IterativeLoggers
 include("utility/iterativesolvers.jl")

+include("utility/allocator.jl")
 include("utility/styles.jl")
 include("utility/periodicarray.jl")
 include("utility/windowarray.jl")

src/algorithms/derivatives/derivatives.jl

Lines changed: 3 additions & 3 deletions

@@ -222,10 +222,10 @@ Base.:*(h::LazySum{<:Union{DerivativeOrMultiplied}}, v) = h(v)
 Given an operator and vector, try to construct a more efficient representation of that operator for repeated application.
 This should always be used in conjunction with [`unprepare_operator!!`](@ref).
 """
-prepare_operator!!(O, backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()) = O
+prepare_operator!!(O, backend::AbstractBackend = DefaultBackend(), allocator = GrowingBuffer()) = O

 # to make benchmark scripts run
-prepare_operator!!(O, x::AbstractTensorMap, backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()) =
+prepare_operator!!(O, x::AbstractTensorMap, backend::AbstractBackend = DefaultBackend(), allocator = GrowingBuffer()) =
     prepare_operator!!(O, backend, allocator), x
-unprepare_operator!!(y, O, x, backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()) =
+unprepare_operator!!(y, O, x, backend::AbstractBackend = DefaultBackend(), allocator = GrowingBuffer()) =
     y
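For reference, a minimal usage sketch of how the prepare/unprepare pair with the new default allocator is meant to bracket repeated applications. This is not part of the diff; `h` and `v0` are hypothetical placeholders for a derivative operator and an MPS tensor:

    buffer = GrowingBuffer()
    h′, v = prepare_operator!!(h, v0, DefaultBackend(), buffer)    # returns the prepared operator and the vector
    for _ in 1:100
        v = h′(v)                                                  # repeated applications reuse the same scratch buffer
    end
    v = unprepare_operator!!(v, h′, v0, DefaultBackend(), buffer)  # undo the preparation (pass-through by default)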

src/algorithms/derivatives/mpo_derivatives.jl

Lines changed: 53 additions & 46 deletions

@@ -134,9 +134,11 @@ function prepare_operator!!(
         H::MPO_AC_Hamiltonian{<:MPSTensor, <:MPOTensor, <:MPSTensor},
         backend::AbstractBackend, allocator
     )
+    cp = checkpoint(allocator)
     @plansor backend = backend allocator = allocator begin
         GL_O[-1 -2; -4 -5 -3] := H.leftenv[-1 1; -4] * H.operators[1][1 -2; -5 -3]
     end
+    reset!(allocator, cp)
     leftenv = fuse_legs(TensorMap(GL_O), 0, 2)
     rightenv = TensorMap(H.rightenv)

@@ -147,66 +149,76 @@ function prepare_operator!!(
         H::MPO_AC2_Hamiltonian{<:MPSTensor, <:MPOTensor, <:MPOTensor, <:MPSTensor},
         backend::AbstractBackend, allocator
     )
+    cp = checkpoint(allocator)
     @plansor backend = backend allocator = allocator begin
         GL_O[-1 -2; -4 -5 -3] := H.leftenv[-1 1; -4] * H.operators[1][1 -2; -5 -3]
         O_GR[-1 -2 -3; -4 -5] := H.operators[2][-3 -5; -2 1] * H.rightenv[-1 1; -4]
     end
+    reset!(allocator, cp)
+
     leftenv = fuse_legs(GL_O isa TensorMap ? GL_O : TensorMap(GL_O), 0, 2)
     rightenv = fuse_legs(O_GR isa TensorMap ? O_GR : TensorMap(O_GR), 2, 0)
     return PrecomputedDerivative(leftenv, rightenv, backend, allocator)
 end


 function (H::PrecomputedDerivative)(x::AbstractTensorMap)
-    R_fused = fuse_legs(H.rightenv, 0, 2)
+    allocator = H.allocator
+    cp = checkpoint(allocator)
+
+    R_fused = fuse_legs(H.rightenv, 0, numin(x))
     x_fused = fuse_legs(x, numout(x), numin(x))

-    # xR = matrix_contract(R_fused, x_fused, 1, One(), H.backend, H.allocator; transpose = true)

     TC = TensorOperations.promote_contract(scalartype(x_fused), scalartype(R_fused))
     xR = TensorOperations.tensoralloc_contract(TC, x_fused, ((1,), (2,)), false, R_fused, ((1,), (2, 3)), false, ((1, 2), (3,)), Val(true), H.allocator)

-    structure_xR = TensorKit.fusionblockstructure(space(xR))
-    structure_R = TensorKit.fusionblockstructure(space(R_fused))
-
-    xblocks = blocks(x_fused)
-    for ((f₁, f₂), i1) in structure_xR.fusiontreeindices
-        sz, str, offset = structure_xR.fusiontreestructure[i1]
-        xr = TensorKit.Strided.StridedView(xR.data, sz, str, offset)
-
-        u = first(f₁.uncoupled)
-        x = TensorKit.Strided.StridedView(xblocks[u])
-        isempty(x) && (zerovector!(xr); continue)
-
-        if haskey(structure_R.fusiontreeindices, (f₁, f₂))
-            @inbounds i = structure_R.fusiontreeindices[(f₁, f₂)]
-            @inbounds sz, str, offset = structure_R.fusiontreestructure[i]
-            r = TensorKit.Strided.StridedView(R_fused.data, sz, str, offset)
-
-            if TensorOperations.isblascontractable(r, ((1,), (2, 3))) &&
-                    TensorOperations.isblasdestination(xr, ((1,), (2, 3)))
-                C = TensorKit.Strided.sreshape(xr, size(xr, 1), size(xr, 2) * size(xr, 3))
-                B = TensorKit.Strided.sreshape(r, size(r, 1), size(r, 2) * size(r, 3))
-                LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
-            elseif sz[2] < sz[3]
-                for k in axes(r, 2)
-                    C = xr[:, k, :]
-                    B = r[:, k, :]
-                    LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
-                end
-            else
-                for k in axes(r, 3)
-                    C = xr[:, :, k]
-                    B = r[:, :, k]
-                    LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
-                end
-            end
-        else
-            zerovector!(xr)
-        end
-    end
+    matrix_contract!(xR, R_fused, x_fused, 1, One(), Zero(), H.backend, H.allocator; transpose = true)
+
+    # structure_xR = TensorKit.fusionblockstructure(space(xR))
+    # structure_R = TensorKit.fusionblockstructure(space(R_fused))
+
+    # xblocks = blocks(x_fused)
+    # for ((f₁, f₂), i1) in structure_xR.fusiontreeindices
+    #     sz, str, offset = structure_xR.fusiontreestructure[i1]
+    #     xr = TensorKit.Strided.StridedView(xR.data, sz, str, offset)
+
+    #     u = first(f₁.uncoupled)
+    #     x = TensorKit.Strided.StridedView(xblocks[u])
+    #     isempty(x) && (zerovector!(xr); continue)
+
+    #     if haskey(structure_R.fusiontreeindices, (f₁, f₂))
+    #         @inbounds i = structure_R.fusiontreeindices[(f₁, f₂)]
+    #         @inbounds sz, str, offset = structure_R.fusiontreestructure[i]
+    #         r = TensorKit.Strided.StridedView(R_fused.data, sz, str, offset)
+
+    #         if TensorOperations.isblascontractable(r, ((1,), (2, 3))) &&
+    #                 TensorOperations.isblasdestination(xr, ((1,), (2, 3)))
+    #             C = TensorKit.Strided.sreshape(xr, size(xr, 1), size(xr, 2) * size(xr, 3))
+    #             B = TensorKit.Strided.sreshape(r, size(r, 1), size(r, 2) * size(r, 3))
+    #             LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
+    #         elseif sz[2] < sz[3]
+    #             for k in axes(r, 2)
+    #                 C = xr[:, k, :]
+    #                 B = r[:, k, :]
+    #                 LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
+    #             end
+    #         else
+    #             for k in axes(r, 3)
+    #                 C = xr[:, :, k]
+    #                 B = r[:, :, k]
+    #                 LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
+    #             end
+    #         end
+    #     else
+    #         zerovector!(xr)
+    #     end
+    # end

     LxR = H.leftenv * xR
+    TensorOperations.tensorfree!(xR, H.allocator)
+
+    reset!(allocator, cp)
     return TensorMap{scalartype(LxR)}(LxR.data, codomain(H.leftenv) ← domain(H.rightenv))
 end

@@ -219,8 +231,3 @@ const _ToPrepare = Union{
 function prepare_operator!!(H::Multiline{<:_ToPrepare}, backend::AbstractBackend, allocator)
     return Multiline(map(x -> prepare_operator!!(x, backend, allocator), parent(H)))
 end
-
-fixedpoint(A::Union{_ToPrepare, Multiline{<:_ToPrepare}}, x₀, which::Symbol, alg::Lanczos) =
-    fixedpoint(prepare_operator!!(A), x₀, which, alg)
-fixedpoint(A::Union{_ToPrepare, Multiline{<:_ToPrepare}}, x₀, which::Symbol, alg::Arnoldi) =
-    fixedpoint(prepare_operator!!(A), x₀, which, alg)
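The changes above all follow the same discipline: take a checkpoint of the allocator before a block of contractions that produces temporaries, and reset to that checkpoint once the temporaries are no longer needed, so every application of the derivative reuses one scratch region instead of allocating fresh tensors. A standalone sketch of that bracketing (illustrative only; `compute` stands in for the actual contractions, and the returned value is assumed not to live in the scratch buffer itself):

    function with_scratch(compute, allocator)
        cp = checkpoint(allocator)   # remember the current bump-pointer offset
        y = compute(allocator)       # may draw temporaries from `allocator`
        reset!(allocator, cp)        # roll back; the next call reuses the same memory
        return y                     # must own its data, not the scratch region
    end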

src/utility/allocator.jl

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+@static if isdefined(Core, :Memory)
+    BufType = Memory{UInt8}
+else
+    BufType = Vector{UInt8}
+end
+
+const DEFAULT_SIZEHINT = 2^20 # 1MB
+
+mutable struct GrowingBuffer
+    buffer::BufType
+    offset::UInt
+    function GrowingBuffer(; sizehint = DEFAULT_SIZEHINT)
+        buffer = BufType(undef, sizehint)
+        return new(buffer, zero(UInt))
+    end
+end
+
+Base.length(buffer::GrowingBuffer) = length(buffer.buffer)
+Base.pointer(buffer::GrowingBuffer) = pointer(buffer.buffer) + buffer.offset
+
+function Base.sizehint!(buffer::GrowingBuffer, n::Integer; shrink::Bool = false)
+    n > 0 || throw(ArgumentError("invalid new buffer size"))
+    buffer.offset == 0 || error("cannot resize a buffer that is not fully reset")
+
+    n = shrink ? max(n, length(buffer)) : n
+    n = Int(Base.nextpow(2, n))
+
+    @static if isdefined(Core, :Memory)
+        buffer.buffer = BufType(undef, n)
+    else
+        sizehint!(buffer.buffer, n)
+    end
+    return buffer
+end
+
+checkpoint(buffer) = zero(UInt)
+reset!(buffer, checkpoint::UInt = zero(UInt)) = buffer
+
+checkpoint(buffer::GrowingBuffer) = buffer.offset
+
+function reset!(buffer::GrowingBuffer, checkpoint::UInt = zero(UInt))
+    if iszero(checkpoint) && buffer.offset > length(buffer)
+        # full reset - check for need to grow
+        newlength = Base.nextpow(2, buffer.offset) # round to nearest larger power of 2
+        buffer.offset = checkpoint
+        sizehint!(buffer, newlength)
+    else
+        buffer.offset = checkpoint
+    end
+    return buffer
+end
+
+# Allocating
+# ----------
+function TensorOperations.tensoralloc(
+        ::Type{A}, structure, ::Val{istemp}, buffer::GrowingBuffer
+    ) where {A <: AbstractArray, istemp}
+    T = eltype(A)
+    if istemp
+        ptr = convert(Ptr{T}, pointer(buffer))
+        buffer.offset += prod(structure) * sizeof(T)
+        buffer.offset < length(buffer) &&
+            return Base.unsafe_wrap(Array, ptr, structure)
+    end
+    return A(undef, structure)
+end
+TensorOperations.tensorfree!(::AbstractArray, ::GrowingBuffer) = nothing
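To illustrate the intended behaviour of the new allocator (a sketch, not part of the commit; the sizes are arbitrary): temporary requests bump `offset`, a request that overflows the backing buffer falls back to an ordinary allocation while still advancing `offset`, and a later full `reset!` then grows the buffer to the next power of two above the peak offset.

    buf = GrowingBuffer(; sizehint = 2^12)   # 4 KiB backing buffer
    a = TensorOperations.tensoralloc(Vector{Float64}, (64,), Val(true), buf)    # 512 B, served from the buffer
    b = TensorOperations.tensoralloc(Vector{Float64}, (1024,), Val(true), buf)  # 8 KiB, overflows -> plain Vector
    TensorOperations.tensorfree!(a, buf)     # no-op; scratch memory is only reclaimed via reset!
    reset!(buf)                              # full reset; grows the buffer to 16 KiB since the peak offset exceeded 4 KiB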
