Skip to content

Commit 2bb044a

Browse files
committed
Improve caching and dispatch of LinearizingSavingCallback
This adds a new type, `LinearizingSavingCallbackCache`, and some sub-types to allow for efficient re-use of memory as the callback executes over the course of a solve, as well as re-use of that memory in future solves when operating on a large ensemble simulation. The top-level `LinearizingSavingCallbackCache` creates thread-safe cache pool objects that are then used to acquire thread-unsafe cache pool objects to be used within a single solve. Those thread-unsafe cache pool objects can then be released and acquired anew by the next solve. The thread-unsafe pool objects allow for acquisition of pieces of memory such as temporary `u` vectors (the recursive nature of the `LinearizingSavingCallback` means that we must allocate unknown numbers of temporary `u` vectors) and chunks of `u` blocks that are then compacted into a single large matrix in the finalize method of the callback. All these pieces of memory are stored within that set of thread-unsafe caches, and these are released back to the top-level thread-safe cache pool, for the next solve to acquire and make use of those pieces of memory in the cache pool. Using these techniques, the solve time of a large ensemble simulation with low per-simulation computation has been reduced dramatically. The simulation solves a 3rd-order Butterworth filter circuit over a certain timespan, swept across different stimulus frequencies and circuit parameters. The parameter sweep results in a 13500-element ensemble simulation that, when run with 8 threads on an M1 Pro, takes: ``` 48.364827 seconds (625.86 M allocations: 19.472 GiB, 41.81% gc time, 0.17% compilation time) ``` Now, after these caching optimizations, we solve the same ensemble in: ``` 13.208123 seconds (166.76 M allocations: 7.621 GiB, 22.21% gc time, 0.61% compilation time) ``` As a side note, the size requirement of the raw linearized solution data itself is `1.04 GB`. 
In general, we expect to allocate somewhere between 2-3x the final output data to account for temporaries and inefficient sharing, so while there is still some more work to be done, this gets us significantly closer to minimal overhead. This also adds a package extension on `Sundials`, as `IDA` requires that state vectors are `NVector` types, rather than `Vector{S}` types in order to not allocate.
1 parent 08933b5 commit 2bb044a

File tree

5 files changed

+472
-115
lines changed

5 files changed

+472
-115
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2121
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2222
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
2323

24+
[extensions]
25+
DiffEqCallbacksSundialsExt = "Sundials"
26+
2427
[compat]
2528
Aqua = "0.8"
2629
DataInterpolations = "4"

src/independentlylinearizedutils.jl

Lines changed: 216 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,144 @@ using SciMLBase
22

33
export IndependentlyLinearizedSolution
44

5+
"""
    CachePool(T, alloc; thread_safe = true)

Simple memory-reusing cache that allows us to grow a cache and keep
re-using those pieces of memory (in our case, typically `u` vectors)
until the solve is finished.  By default, this datastructure is made
to be thread-safe by locking on every acquire and release, but it
can be made thread-unsafe (and correspondingly faster) by passing
`thread_safe = false` to the constructor.

# Arguments
- `T`: the type of the values handed out by the pool.
- `alloc`: zero-argument callable invoked to create a fresh `T` whenever the
  pool has no released value available.

# Keywords
- `thread_safe::Bool = true`: selects the `THREAD_SAFE` type parameter
  (`Val{true}` or `Val{false}`), which is what the `acquire!()`/`release!()`
  methods dispatch on.

While manual usage with `acquire!()` and `release!()` is possible,
most users will want to use `@with_cache`, which provides lexically-
scoped `acquire!()` and `release!()` usage automatically.  Example:

```julia
us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false)
@with_cache us u_prev begin
    @with_cache us u_next begin
        # perform tasks with these two `u` vectors
    end
end
```

!!! warning "Escaping values"
    You must not use an acquired value after you have released it;
    the memory may be immediately re-used by some other consumer of
    your cache pool.  Do not allow the acquired value to escape
    outside of the `@with_cache` block, or past a `release!()`.
"""
mutable struct CachePool{T, THREAD_SAFE}
    # Values that have been released and are available for re-use
    const pool::Vector{T}
    # Zero-argument allocator used when `pool` is empty
    const alloc::Function
    # Guards every operation on thread-safe pools.  Declared `const` because
    # replacing the lock of a live pool would silently break mutual exclusion.
    const lock::ReentrantLock
    # Number of times `alloc` has been used to create a new value
    num_alloced::Int
    # Number of values currently handed out and not yet released
    num_acquired::Int

    function CachePool(T, alloc::F; thread_safe::Bool = true) where {F}
        return new{T,Val{thread_safe}}(T[], alloc, ReentrantLock(), 0, 0)
    end
end
const ThreadSafeCachePool{T} = CachePool{T,Val{true}}
const ThreadUnsafeCachePool{T} = CachePool{T,Val{false}}
48+
"""
    acquire!(cache::CachePool)

Returns a cached element of the cache pool, calling `cache.alloc()` if none
are available.

The bookkeeping counters (`num_alloced`, `num_acquired`) are only updated once
a value has actually been obtained, so an allocator that throws does not leave
the pool believing that a value is still outstanding.

The `_dummy` positional argument exists only so that the thread-safe method
can forward to this implementation by passing an explicit second argument.
"""
Base.@inline function acquire!(cache::CachePool{T}, _dummy = nothing) where {T}
    if isempty(cache.pool)
        # Allocate first; only count the allocation if it succeeded
        val = cache.alloc()::T
        cache.num_alloced += 1
    else
        val = pop!(cache.pool)
    end
    cache.num_acquired += 1
    return val
end
63+
"""
    release!(cache::CachePool, val)

Hands `val` back to `cache` so that a later `acquire!()` can re-use it, and
decrements the count of outstanding values.  The updated count is the value
of the call.
"""
Base.@inline function release!(cache::CachePool, val, _dummy = nothing)
    push!(cache.pool, val)
    outstanding = cache.num_acquired - 1
    cache.num_acquired = outstanding
    return outstanding
end
73+
# `true` when every value handed out by `acquire!()` has been given back
is_fully_released(cache::CachePool, _dummy = nothing) = cache.num_acquired == 0
77+
# Thread-safe pools take the pool's lock and then defer to the generic
# implementations; passing the explicit `nothing` selects the two/three
# argument methods rather than recursing into these wrappers.
function acquire!(cache::ThreadSafeCachePool)
    return lock(cache.lock) do
        acquire!(cache, nothing)
    end
end
function release!(cache::ThreadSafeCachePool, val)
    return lock(cache.lock) do
        release!(cache, val, nothing)
    end
end
function is_fully_released(cache::ThreadSafeCachePool)
    return lock(cache.lock) do
        is_fully_released(cache, nothing)
    end
end
82+
"""
    @with_cache cache name body

Acquire a value from `cache`, bind it to `name` in the caller's scope, run
`body`, and release the value back to `cache` when `body` exits (normally or
by throwing).

The `cache` expression is evaluated exactly once, and the object that gets
released is always the one that was acquired, even if `body` rebinds `name`.
"""
macro with_cache(cache, name, body)
    return quote
        # Hygienic locals: evaluate the cache expression once and remember
        # exactly which object was handed out.
        local pool = $(esc(cache))
        local acquired = acquire!(pool)
        $(esc(name)) = acquired
        try
            $(esc(body))
        finally
            release!(pool, acquired)
        end
    end
end
93+
94+
# Bundle of the three thread-unsafe pools that back a single
# `IndependentlyLinearizedSolutionChunks`: time chunks, per-`u` value chunks
# and time-mask chunks.  All chunks hold `chunk_size` columns/elements.
struct IndependentlyLinearizedSolutionChunksCache{T,S}
    t_chunks::ThreadUnsafeCachePool{Vector{T}}
    u_chunks::ThreadUnsafeCachePool{Matrix{S}}
    time_masks::ThreadUnsafeCachePool{BitMatrix}

    function IndependentlyLinearizedSolutionChunksCache{T,S}(num_us::Int, num_derivatives::Int, chunk_size::Int) where {T,S}
        # One row per stored derivative, plus one for the primal value
        num_rows = num_derivatives + 1
        t_pool = CachePool(Vector{T}, () -> Vector{T}(undef, chunk_size); thread_safe=false)
        u_pool = CachePool(Matrix{S}, () -> Matrix{S}(undef, num_rows, chunk_size); thread_safe=false)
        mask_pool = CachePool(BitMatrix, () -> BitMatrix(undef, num_us, chunk_size); thread_safe=false)
        return new{T,S}(t_pool, u_pool, mask_pool)
    end
end
111+
"""
    IndependentlyLinearizedSolutionChunks

When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
we use this intermediate structure to reduce allocations and collect the unknown number of timesteps
that the solve will generate.

The type parameter `N` records the number of derivatives stored alongside each
primal value; all chunk storage is obtained from `cache`.
"""
mutable struct IndependentlyLinearizedSolutionChunks{T, S, N}
    t_chunks::Vector{Vector{T}}
    u_chunks::Vector{Vector{Matrix{S}}}
    time_masks::Vector{BitMatrix}

    # Temporary array that gets used by `get_chunks`
    last_chunks::Vector{Matrix{S}}

    # Index of next write into the last chunk
    u_offsets::Vector{Int}
    t_offset::Int

    # Pools that every chunk above was acquired from
    cache::IndependentlyLinearizedSolutionChunksCache

    function IndependentlyLinearizedSolutionChunks{T, S}(
            num_us::Int,
            num_derivatives::Int = 0,
            chunk_size::Int = 512,
            cache::IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)
        ) where {T, S}
        # Start every series with a single chunk taken from the cache
        first_ts = acquire!(cache.t_chunks)
        per_u_chunks = [[acquire!(cache.u_chunks)] for _ in 1:num_us]
        first_mask = acquire!(cache.time_masks)
        newest_chunks = [per_u_chunks[u_idx][1] for u_idx in 1:num_us]
        write_positions = [1 for _ in 1:num_us]
        return new{T,S,num_derivatives}(
            [first_ts],
            per_u_chunks,
            [first_mask],
            newest_chunks,
            write_positions,
            1,
            cache,
        )
    end
end
31145

@@ -44,14 +158,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks)
44158
end
45159
return length(ilsc.u_chunks)
46160
end
# The derivative count is carried in the third type parameter
function num_derivatives(::IndependentlyLinearizedSolutionChunks{<:Any,<:Any,N}) where {N}
    return N
end
47162

48-
function num_derivatives(ilsc::IndependentlyLinearizedSolutionChunks)
49-
# If we've been finalized, just return `0` (which means only the primal)
50-
if isempty(ilsc.t_chunks)
51-
return 0
52-
end
53-
return size(first(first(ilsc.u_chunks)), 1) - 1
54-
end
55163

56164
function Base.isempty(ilsc::IndependentlyLinearizedSolutionChunks)
57165
return length(ilsc.t_chunks) == 1 && ilsc.t_offset == 1
@@ -61,24 +169,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T,
61169
# Check if we need to allocate new `t` chunk
62170
chunksize = chunk_size(ilsc)
63171
if ilsc.t_offset > chunksize
64-
push!(ilsc.t_chunks, Vector{T}(undef, chunksize))
65-
push!(ilsc.time_masks, BitMatrix(undef, length(ilsc.u_offsets), chunksize))
172+
push!(ilsc.t_chunks, acquire!(ilsc.cache.t_chunks))
173+
push!(ilsc.time_masks, acquire!(ilsc.cache.time_masks))
66174
ilsc.t_offset = 1
67175
end
68176

69177
# Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
70178
for (u_idx, u_chunks) in enumerate(ilsc.u_chunks)
71179
if ilsc.u_offsets[u_idx] > chunksize
72-
push!(u_chunks, Matrix{S}(undef, num_derivatives(ilsc)+1, chunksize))
180+
push!(u_chunks, acquire!(ilsc.cache.u_chunks))
73181
ilsc.u_offsets[u_idx] = 1
74182
end
183+
ilsc.last_chunks[u_idx] = u_chunks[end]
75184
end
76185

77186
# return the last chunk for each
78187
return (
79188
ilsc.t_chunks[end],
80189
ilsc.time_masks[end],
81-
[u_chunks[end] for u_chunks in ilsc.u_chunks],
190+
ilsc.last_chunks,
82191
)
83192
end
84193

@@ -135,16 +244,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
135244
ts, time_mask, us = get_chunks(ilsc)
136245

137246
# Store into the chunks, gated by `u_mask`
138-
for u_idx in 1:size(u, 2)
247+
@inbounds for u_idx in 1:size(u, 2)
139248
if u_mask[u_idx]
140249
for deriv_idx in 1:size(u, 1)
141250
us[u_idx][deriv_idx, ilsc.u_offsets[u_idx]] = u[deriv_idx, u_idx]
142251
end
143252
ilsc.u_offsets[u_idx] += 1
144253
end
254+
255+
# Update our `time_mask` while we're at it
256+
time_mask[u_idx, ilsc.t_offset] = u_mask[u_idx]
145257
end
146258
ts[ilsc.t_offset] = t
147-
time_mask[:, ilsc.t_offset] .= u_mask
148259
ilsc.t_offset += 1
149260
end
150261

@@ -161,7 +272,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views
161272
of the state variables at all timepoints, as well as an efficient `sample!()`
162273
method that can sample at arbitrary timesteps.
163274
"""
164-
mutable struct IndependentlyLinearizedSolution{T, S}
275+
mutable struct IndependentlyLinearizedSolution{T, S, N}
165276
# All timepoints, shared by all `us`
166277
ts::Vector{T}
167278

@@ -173,33 +284,42 @@ mutable struct IndependentlyLinearizedSolution{T, S}
173284
time_mask::BitMatrix
174285

175286
# Temporary object used during construction, will be set to `nothing` at the end.
176-
ilsc::Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}}
287+
ilsc::Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S,N}}
288+
ilsc_cache_pool::Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}}
177289
end
# Helper function to create an ILS wrapped around an in-progress ILSC.
# The solution starts out with empty `ts`, `us` and `time_mask`; `cache_pool`
# (if any) is stored alongside the chunks.
function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}, cache_pool = nothing) where {T,S,N}
    empty_ts = T[]
    empty_us = Matrix{S}[]
    empty_mask = BitMatrix(undef, 0,0)
    return IndependentlyLinearizedSolution{T,S,N}(empty_ts, empty_us, empty_mask, ilsc, cache_pool)
end
# Automatically create an ILS wrapped around an ILSC from a `prob`.
#
# - `num_derivatives`: number of derivatives stored next to each primal value.
# - `cache_pool`: optional pool from which the chunk cache is acquired; when
#   `nothing`, a fresh cache is built for this solution.
# - `chunk_size`: chunk length used when a fresh cache has to be built.
function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem, num_derivatives = 0;
                                         cache_pool = nothing,
                                         chunk_size::Int = 512)
    T = eltype(prob.tspan)
    # Use one state element type for both the cache and the chunks, so the two
    # always agree (including the `u0 === nothing` case).
    S = isnothing(prob.u0) ? Float64 : eltype(prob.u0)
    num_us = isnothing(prob.u0) ? 0 : length(prob.u0)
    if cache_pool === nothing
        cache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)
    else
        cache = acquire!(cache_pool)
    end
    chunks = IndependentlyLinearizedSolutionChunks{T,S}(num_us, num_derivatives, chunk_size, cache)
    # Hand the pool to the solution so the acquired cache can be released
    # back to it once the solution no longer needs the chunks.
    return IndependentlyLinearizedSolution(chunks, cache_pool)
end
196316

# Number of derivatives stored per state, read from the type parameter `N`
function num_derivatives(::IndependentlyLinearizedSolution{<:Any,<:Any,N}) where {N}
    return N
end

# Number of independently-linearized states
function num_us(ils::IndependentlyLinearizedSolution)
    return length(ils.us)
end

# Dimensions of the time mask
function Base.size(ils::IndependentlyLinearizedSolution)
    return size(ils.time_mask)
end

# Number of stored timepoints
function Base.length(ils::IndependentlyLinearizedSolution)
    return length(ils.ts)
end
201321

202-
function finish!(ils::IndependentlyLinearizedSolution)
322+
function finish!(ils::IndependentlyLinearizedSolution{T,S}) where {T,S}
203323
function trim_chunk(chunks::Vector, offset)
204324
chunks = [chunk for chunk in chunks]
205325
if eltype(chunks) <: AbstractVector
@@ -216,10 +336,52 @@ function finish!(ils::IndependentlyLinearizedSolution)
216336
end
217337

218338
ilsc = ils.ilsc::IndependentlyLinearizedSolutionChunks
219-
ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
220-
time_mask = hcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
221-
us = [hcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)
222-
for u_idx in 1:length(ilsc.u_chunks)]
339+
340+
chunk_len(chunk) = size(chunk, ndims(chunk))
341+
function chunks_len(chunks::Vector, offset)
342+
len = 0
343+
for chunk_idx in 1:length(chunks)-1
344+
len += chunk_len(chunks[chunk_idx])
345+
end
346+
return len + offset - 1
347+
end
348+
349+
function copy_chunk!(out::Vector, in::Vector, out_offset::Int, len=chunk_len(in))
350+
for idx in 1:len
351+
out[idx+out_offset] = in[idx]
352+
end
353+
end
354+
function copy_chunk!(out::AbstractMatrix, in::AbstractMatrix, out_offset::Int, len=chunk_len(in))
355+
for zdx in 1:size(in, 1)
356+
for idx in 1:len
357+
out[zdx, idx+out_offset] = in[zdx, idx]
358+
end
359+
end
360+
end
361+
362+
function collapse_chunks!(out, chunks, offset::Int)
363+
write_offset = 0
364+
for chunk_idx in 1:(length(chunks)-1)
365+
chunk = chunks[chunk_idx]
366+
copy_chunk!(out, chunk, write_offset)
367+
write_offset += chunk_len(chunk)
368+
end
369+
copy_chunk!(out, chunks[end], write_offset, offset-1)
370+
end
371+
372+
# Collapse t_chunks
373+
ts = Vector{T}(undef, chunks_len(ilsc.t_chunks, ilsc.t_offset))
374+
collapse_chunks!(ts, ilsc.t_chunks, ilsc.t_offset)
375+
376+
# Collapse u_chunks
377+
us = Vector{Matrix{S}}(undef, length(ilsc.u_chunks))
378+
for u_idx in 1:length(ilsc.u_chunks)
379+
us[u_idx] = Matrix{S}(undef, size(ilsc.u_chunks[u_idx][1],1), chunks_len(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx]))
380+
collapse_chunks!(us[u_idx], ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])
381+
end
382+
383+
time_mask = BitMatrix(undef, size(ilsc.time_masks[1], 1), chunks_len(ilsc.time_masks, ilsc.t_offset))
384+
collapse_chunks!(time_mask, ilsc.time_masks, ilsc.t_offset)
223385

224386
# Sanity-check lengths
225387
if length(ts) != size(time_mask, 2)
@@ -238,7 +400,24 @@ function finish!(ils::IndependentlyLinearizedSolution)
238400
throw(ArgumentError("Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens))"))
239401
end
240402

241-
# Update our struct, release the `ilsc`
403+
# Update our struct, release the `ilsc` and its caches
404+
for t_chunk in ilsc.t_chunks
405+
release!(ilsc.cache.t_chunks, t_chunk)
406+
end
407+
@assert is_fully_released(ilsc.cache.t_chunks)
408+
for u_idx in 1:length(ilsc.u_chunks)
409+
for u_chunk in ilsc.u_chunks[u_idx]
410+
release!(ilsc.cache.u_chunks, u_chunk)
411+
end
412+
end
413+
@assert is_fully_released(ilsc.cache.u_chunks)
414+
for time_mask in ilsc.time_masks
415+
release!(ilsc.cache.time_masks, time_mask)
416+
end
417+
@assert is_fully_released(ilsc.cache.time_masks)
418+
if ils.ilsc_cache_pool !== nothing
419+
release!(ils.ilsc_cache_pool, ilsc.cache)
420+
end
242421
ils.ilsc = nothing
243422
ils.ts = ts
244423
ils.us = us

0 commit comments

Comments
 (0)