Commit 58d4681

Improve caching and dispatch of LinearizingSavingCallback
This adds a new type, `LinearizingSavingCallbackCache`, and some sub-types to allow for efficient re-use of memory as the callback executes over the course of a solve, as well as re-use of that memory in future solves when operating on a large ensemble simulation.

The top-level `LinearizingSavingCallbackCache` creates thread-safe cache pool objects, which are then used to acquire thread-unsafe cache pool objects for use within a single solve. Those thread-unsafe cache pools can be released and acquired anew by the next solve. The thread-unsafe pools allow for acquisition of pieces of memory such as temporary `u` vectors (the recursive nature of the `LinearizingSavingCallback` means that we must allocate an unknown number of temporary `u` vectors) and chunks of `u` blocks that are compacted into a single large matrix in the finalize method of the callback. All of these pieces of memory are stored within that set of thread-unsafe caches, which are then released back to the top-level thread-safe cache pool for the next solve to acquire and re-use.

Using these techniques, the solve time of a large ensemble simulation with low per-simulation computation has dropped dramatically. The simulation solves a Butterworth 3rd-order filter circuit over a certain timespan, swept across different stimulus frequencies and circuit parameters. The parameter sweep results in a 13500-element ensemble simulation that, when run with 8 threads on an M1 Pro, previously took:

```
48.364827 seconds (625.86 M allocations: 19.472 GiB, 41.81% gc time, 0.17% compilation time)
```

After these caching optimizations, we solve the same ensemble in:

```
13.208123 seconds (166.76 M allocations: 7.621 GiB, 22.21% gc time, 0.61% compilation time)
```

As a side note, the raw linearized solution data itself comes to `1.04 GB`. In general, we expect to allocate somewhere between 2-3x the final output data to account for temporaries and inefficient sharing, so while there is still some more work to be done, this gets us significantly closer to minimal overhead.

This also adds a package extension on `Sundials`, as `IDA` requires that state vectors be `NVector` types, rather than `Vector{S}` types, in order to not allocate.

1 parent be310b8 commit 58d4681
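
To make the description above concrete, here is a rough sketch of the two-level pooling flow using the `CachePool` machinery this commit adds in `src/independentlylinearizedutils.jl`. It is illustrative only: these helpers are internal and unexported, the loop stands in for real per-trajectory solves, and none of the `LinearizingSavingCallback` wiring is shown.

```julia
using DiffEqCallbacks: CachePool, ThreadUnsafeCachePool, acquire!, release!, @with_cache

# Illustrative two-level cache: a thread-safe top-level pool hands out
# thread-unsafe per-solve pools of temporary `u` buffers.
num_us = 4
per_solve_alloc = () -> CachePool(Vector{Float64},
                                  () -> Vector{Float64}(undef, num_us);
                                  thread_safe = false)
top_level = CachePool(ThreadUnsafeCachePool{Vector{Float64}}, per_solve_alloc;
                      thread_safe = true)

Threads.@threads for i in 1:16
    us = acquire!(top_level)   # per-"solve" pool, used unlocked on this task
    @with_cache us u begin
        u .= i                 # temporary `u` buffer, re-used across iterations
    end
    release!(top_level, us)    # hand the pool back for the next "solve" to re-use
end
```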

File tree

6 files changed: +501 -125 lines changed


Project.toml

Lines changed: 3 additions & 0 deletions

```diff
@@ -21,6 +21,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
 
+[extensions]
+DiffEqCallbacksSundialsExt = "Sundials"
+
 [compat]
 Aqua = "0.8"
 DataInterpolations = "4"
```

ext/DiffEqCallbacksSundialsExt.jl

Lines changed: 13 additions & 0 deletions

```diff
@@ -0,0 +1,13 @@
+module DiffEqCallbacksSundialsExt
+
+using Sundials: NVector, IDA
+import DiffEqCallbacks: solver_state_alloc, solver_state_type
+
+
+# Allocator; `U` is typically something like `Vector{Float64}`
+solver_state_alloc(solver::IDA, U::DataType, num_us::Int) = () -> NVector(U(undef, num_us))
+
+# Type of `solver_state_alloc`, which is just `NVector`
+solver_state_type(solver::IDA, U::DataType) = NVector
+
+end # module
```
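
For contrast, here is a guess at what a generic fallback for these two hooks might look like; the real defaults live elsewhere in DiffEqCallbacks and are not part of this diff. The point of the IDA methods above is to hand Sundials an `NVector` up front, avoiding a `Vector{S}` to `NVector` conversion (and allocation) every time the state buffer is used.

```julia
# Hypothetical generic fallbacks -- NOT from this commit; shown only to contrast
# with the IDA-specific methods in the extension above.
solver_state_alloc(solver, U::DataType, num_us::Int) = () -> U(undef, num_us)
solver_state_type(solver, U::DataType) = U
```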

src/independentlylinearizedutils.jl

Lines changed: 217 additions & 44 deletions

````diff
@@ -2,31 +2,144 @@ using SciMLBase
 
 export IndependentlyLinearizedSolution
 
+
+"""
+    CachePool(T, alloc; thread_safe = true)
+
+Simple memory-reusing cache that allows us to grow a cache and keep
+re-using those pieces of memory (in our case, typically `u` vectors)
+until the solve is finished. By default, this datastructure is made
+to be thread-safe by locking on every acquire and release, but it
+can be made thread-unsafe (and correspondingly faster) by passing
+`thread_safe = false` to the constructor.
+
+While manual usage with `acquire!()` and `release!()` is possible,
+most users will want to use `@with_cache`, which provides lexically-
+scoped `acquire!()` and `release!()` usage automatically. Example:
+
+```julia
+us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false)
+@with_cache us u_prev begin
+    @with_cache us u_next begin
+        # perform tasks with these two `u` vectors
+    end
+end
+```
+
+!!! warning "Escaping values"
+    You must not use an acquired value after you have released it;
+    the memory may be immediately re-used by some other consumer of
+    your cache pool. Do not allow the acquired value to escape
+    outside of the `@with_cache` block, or past a `release!()`.
+"""
+mutable struct CachePool{T, THREAD_SAFE}
+    const pool::Vector{T}
+    const alloc::Function
+    lock::ReentrantLock
+    num_allocated::Int
+    num_acquired::Int
+
+    function CachePool(T, alloc::F; thread_safe::Bool = true) where {F}
+        return new{T,Val{thread_safe}}(T[], alloc, ReentrantLock(), 0, 0)
+    end
+end
+const ThreadSafeCachePool{T} = CachePool{T,Val{true}}
+const ThreadUnsafeCachePool{T} = CachePool{T,Val{false}}
+
+"""
+    acquire!(cache::CachePool)
+
+Returns a cached element of the cache pool, calling `cache.alloc()` if none
+are available.
+"""
+Base.@inline function acquire!(cache::CachePool{T}, _dummy = nothing) where {T}
+    cache.num_acquired += 1
+    if isempty(cache.pool)
+        cache.num_allocated += 1
+        return cache.alloc()::T
+    end
+    return pop!(cache.pool)
+end
+
+"""
+    release!(cache::CachePool, val)
+
+Returns the value `val` to the cache pool.
+"""
+Base.@inline function release!(cache::CachePool, val, _dummy = nothing)
+    push!(cache.pool, val)
+    cache.num_acquired -= 1
+end
+
+function is_fully_released(cache::CachePool, _dummy = nothing)
+    return cache.num_acquired == 0
+end
+
+# Thread-safe versions just sub out to the other methods, using `_dummy` to force correct dispatch
+acquire!(cache::ThreadSafeCachePool) = @lock cache.lock acquire!(cache, nothing)
+release!(cache::ThreadSafeCachePool, val) = @lock cache.lock release!(cache, val, nothing)
+is_fully_released(cache::ThreadSafeCachePool) = @lock cache.lock is_fully_released(cache, nothing)
+
+macro with_cache(cache, name, body)
+    return quote
+        $(esc(name)) = acquire!($(esc(cache)))
+        try
+            $(esc(body))
+        finally
+            release!($(esc(cache)), $(esc(name)))
+        end
+    end
+end
+
+
+struct IndependentlyLinearizedSolutionChunksCache{T,S}
+    t_chunks::ThreadUnsafeCachePool{Vector{T}}
+    u_chunks::ThreadUnsafeCachePool{Matrix{S}}
+    time_masks::ThreadUnsafeCachePool{BitMatrix}
+
+    function IndependentlyLinearizedSolutionChunksCache{T,S}(num_us::Int, num_derivatives::Int, chunk_size::Int) where {T,S}
+        t_chunks_alloc = () -> Vector{T}(undef, chunk_size)
+        u_chunks_alloc = () -> Matrix{S}(undef, num_derivatives+1, chunk_size)
+        time_masks_alloc = () -> BitMatrix(undef, num_us, chunk_size)
+        return new(
+            CachePool(Vector{T}, t_chunks_alloc; thread_safe=false),
+            CachePool(Matrix{S}, u_chunks_alloc; thread_safe=false),
+            CachePool(BitMatrix, time_masks_alloc; thread_safe=false),
+        )
+    end
+end
+
 """
     IndependentlyLinearizedSolutionChunks
 
 When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
 we use this intermediate structure to reduce allocations and collect the unknown number of timesteps
 that the solve will generate.
 """
-mutable struct IndependentlyLinearizedSolutionChunks{T, S}
+mutable struct IndependentlyLinearizedSolutionChunks{T, S, N}
     t_chunks::Vector{Vector{T}}
     u_chunks::Vector{Vector{Matrix{S}}}
     time_masks::Vector{BitMatrix}
 
+    # Temporary array that gets used by `get_chunks`
+    last_chunks::Vector{Matrix{S}}
+
     # Index of next write into the last chunk
     u_offsets::Vector{Int}
     t_offset::Int
 
-    function IndependentlyLinearizedSolutionChunks{T, S}(
-            num_us::Int, num_derivatives::Int = 0,
-            chunk_size::Int = 100) where {T, S}
-        return new([Vector{T}(undef, chunk_size)],
-            [[Matrix{S}(undef, num_derivatives + 1, chunk_size)] for _ in 1:num_us],
-            [BitMatrix(undef, num_us, chunk_size)],
-            [1 for _ in 1:num_us],
-            1
-        )
+    cache::IndependentlyLinearizedSolutionChunksCache
+
+    function IndependentlyLinearizedSolutionChunks{T, S}(num_us::Int, num_derivatives::Int = 0,
+            chunk_size::Int = 512,
+            cache::IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)) where {T, S}
+        t_chunks = [acquire!(cache.t_chunks)]
+        u_chunks = [[acquire!(cache.u_chunks)] for _ in 1:num_us]
+        time_masks = [acquire!(cache.time_masks)]
+        last_chunks = [u_chunks[u_idx][1] for u_idx in 1:num_us]
+        u_offsets = [1 for _ in 1:num_us]
+        t_offset = 1
+        return new{T,S,num_derivatives}(t_chunks, u_chunks, time_masks, last_chunks, u_offsets, t_offset, cache)
     end
 end
 
````
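
As a hedged usage sketch of the pieces defined above (internal, unexported API; the module-qualified imports are assumptions): a per-solve `IndependentlyLinearizedSolutionChunksCache` owns three thread-unsafe pools, and an `IndependentlyLinearizedSolutionChunks` seeds its first chunks from that cache instead of allocating fresh arrays.

```julia
using DiffEqCallbacks: IndependentlyLinearizedSolutionChunksCache,
                       IndependentlyLinearizedSolutionChunks

T, S = Float64, Float64
num_us, num_derivatives, chunk_size = 3, 1, 512

# Per-solve cache: three thread-unsafe pools (t chunks, u chunks, time masks).
cache = IndependentlyLinearizedSolutionChunksCache{T, S}(num_us, num_derivatives, chunk_size)

# Chunk storage for one solve, seeded from the cache; `finish!()` later releases
# every chunk back into `cache` so a subsequent solve can re-use the memory.
ilsc = IndependentlyLinearizedSolutionChunks{T, S}(num_us, num_derivatives, chunk_size, cache)
```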

```diff
@@ -45,14 +158,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks)
     end
     return length(ilsc.u_chunks)
 end
+num_derivatives(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}) where {T,S,N} = N
 
-function num_derivatives(ilsc::IndependentlyLinearizedSolutionChunks)
-    # If we've been finalized, just return `0` (which means only the primal)
-    if isempty(ilsc.t_chunks)
-        return 0
-    end
-    return size(first(first(ilsc.u_chunks)), 1) - 1
-end
 
 function Base.isempty(ilsc::IndependentlyLinearizedSolutionChunks)
     return length(ilsc.t_chunks) == 1 && ilsc.t_offset == 1
```

```diff
@@ -62,24 +169,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T,
     # Check if we need to allocate new `t` chunk
     chunksize = chunk_size(ilsc)
     if ilsc.t_offset > chunksize
-        push!(ilsc.t_chunks, Vector{T}(undef, chunksize))
-        push!(ilsc.time_masks, BitMatrix(undef, length(ilsc.u_offsets), chunksize))
+        push!(ilsc.t_chunks, acquire!(ilsc.cache.t_chunks))
+        push!(ilsc.time_masks, acquire!(ilsc.cache.time_masks))
         ilsc.t_offset = 1
     end
 
     # Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
     for (u_idx, u_chunks) in enumerate(ilsc.u_chunks)
         if ilsc.u_offsets[u_idx] > chunksize
-            push!(u_chunks, Matrix{S}(undef, num_derivatives(ilsc) + 1, chunksize))
+            push!(u_chunks, acquire!(ilsc.cache.u_chunks))
            ilsc.u_offsets[u_idx] = 1
         end
+        ilsc.last_chunks[u_idx] = u_chunks[end]
     end
 
     # return the last chunk for each
     return (
         ilsc.t_chunks[end],
         ilsc.time_masks[end],
-        [u_chunks[end] for u_chunks in ilsc.u_chunks]
+        ilsc.last_chunks,
     )
 end
 
```

```diff
@@ -137,16 +245,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
     ts, time_mask, us = get_chunks(ilsc)
 
     # Store into the chunks, gated by `u_mask`
-    for u_idx in 1:size(u, 2)
+    @inbounds for u_idx in 1:size(u, 2)
         if u_mask[u_idx]
             for deriv_idx in 1:size(u, 1)
                 us[u_idx][deriv_idx, ilsc.u_offsets[u_idx]] = u[deriv_idx, u_idx]
             end
             ilsc.u_offsets[u_idx] += 1
         end
+
+        # Update our `time_mask` while we're at it
+        time_mask[u_idx, ilsc.t_offset] = u_mask[u_idx]
     end
     ts[ilsc.t_offset] = t
-    time_mask[:, ilsc.t_offset] .= u_mask
     ilsc.t_offset += 1
 end
 
```

```diff
@@ -161,7 +271,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views
 of the state variables at all timepoints, as well as an efficient `sample!()`
 method that can sample at arbitrary timesteps.
 """
-mutable struct IndependentlyLinearizedSolution{T, S}
+mutable struct IndependentlyLinearizedSolution{T, S, N}
     # All timepoints, shared by all `us`
     ts::Vector{T}
 
```

```diff
@@ -173,32 +283,37 @@ mutable struct IndependentlyLinearizedSolution{T, S}
     time_mask::BitMatrix
 
     # Temporary object used during construction, will be set to `nothing` at the end.
-    ilsc::Union{Nothing, IndependentlyLinearizedSolutionChunks{T, S}}
+    ilsc::Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S,N}}
+    ilsc_cache_pool::Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}}
 end
 # Helper function to create an ILS wrapped around an in-progress ILSC
-function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{
-        T, S}) where {T, S}
-    ils = IndependentlyLinearizedSolution(
+function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}, cache_pool = nothing) where {T,S,N}
+    return IndependentlyLinearizedSolution{T,S,N}(
         T[],
         Matrix{S}[],
-        BitMatrix(undef, 0, 0),
-        ilsc
+        BitMatrix(undef, 0,0),
+        ilsc,
+        cache_pool,
     )
-    return ils
 end
 # Automatically create an ILS wrapped around an ILSC from a `prob`
-function IndependentlyLinearizedSolution(
-        prob::SciMLBase.AbstractDEProblem, num_derivatives = 0)
+function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem, num_derivatives = 0;
+        cache_pool = nothing,
+        chunk_size::Int = 512)
     T = eltype(prob.tspan)
+    S = eltype(prob.u0)
     U = isnothing(prob.u0) ? Float64 : eltype(prob.u0)
-    N = isnothing(prob.u0) ? 0 : length(prob.u0)
-    chunks = IndependentlyLinearizedSolutionChunks{T, U}(N, num_derivatives)
-    return IndependentlyLinearizedSolution(chunks)
+    num_us = isnothing(prob.u0) ? 0 : length(prob.u0)
+    if cache_pool === nothing
+        cache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)
+    else
+        cache = acquire!(cache_pool)
+    end
+    chunks = IndependentlyLinearizedSolutionChunks{T,U}(num_us, num_derivatives, chunk_size, cache)
+    return IndependentlyLinearizedSolution(chunks, cache_pool)
 end
 
-function num_derivatives(ils::IndependentlyLinearizedSolution)
-    !isempty(ils.us) ? size(first(ils.us), 1) : 0
-end
+num_derivatives(::IndependentlyLinearizedSolution{T,S,N}) where {T,S,N} = N
 num_us(ils::IndependentlyLinearizedSolution) = length(ils.us)
 Base.size(ils::IndependentlyLinearizedSolution) = size(ils.time_mask)
 Base.length(ils::IndependentlyLinearizedSolution) = length(ils.ts)
```
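
A hedged sketch of the new `cache_pool` keyword in use, as an ensemble driver might employ it (the toy `ODEProblem` and parameter values are illustrative, and the callback wiring is omitted): one thread-safe pool of chunk caches is shared across solves, each `IndependentlyLinearizedSolution` acquires a cache on construction, and `finish!` returns it to the pool.

```julia
using SciMLBase: ODEProblem
using DiffEqCallbacks: CachePool, IndependentlyLinearizedSolutionChunksCache,
                       IndependentlyLinearizedSolution

prob = ODEProblem((du, u, p, t) -> (du .= -u), ones(3), (0.0, 1.0))

T, S = Float64, Float64
num_us, num_derivatives, chunk_size = 3, 0, 512

# Shared, thread-safe pool of per-solve chunk caches.
pool = CachePool(IndependentlyLinearizedSolutionChunksCache{T, S},
                 () -> IndependentlyLinearizedSolutionChunksCache{T, S}(num_us, num_derivatives, chunk_size);
                 thread_safe = true)

# Each trajectory acquires a per-solve cache from `pool` here; `finish!()`
# releases it back once the solve's chunks have been compacted.
ils = IndependentlyLinearizedSolution(prob, num_derivatives;
                                      cache_pool = pool, chunk_size = chunk_size)
```
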
```diff
@@ -226,10 +341,51 @@ function finish!(ils::IndependentlyLinearizedSolution{T, S}, return_code) where
         us = Vector{Matrix{S}}()
         time_mask = BitMatrix(undef, 0, 0)
     else
-        ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
-        time_mask = hcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
-        us = [hcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)
-              for u_idx in 1:length(ilsc.u_chunks)]
+        chunk_len(chunk) = size(chunk, ndims(chunk))
+        function chunks_len(chunks::Vector, offset)
+            len = 0
+            for chunk_idx in 1:length(chunks)-1
+                len += chunk_len(chunks[chunk_idx])
+            end
+            return len + offset - 1
+        end
+
+        function copy_chunk!(out::Vector, in::Vector, out_offset::Int, len=chunk_len(in))
+            for idx in 1:len
+                out[idx+out_offset] = in[idx]
+            end
+        end
+        function copy_chunk!(out::AbstractMatrix, in::AbstractMatrix, out_offset::Int, len=chunk_len(in))
+            for zdx in 1:size(in, 1)
+                for idx in 1:len
+                    out[zdx, idx+out_offset] = in[zdx, idx]
+                end
+            end
+        end
+
+        function collapse_chunks!(out, chunks, offset::Int)
+            write_offset = 0
+            for chunk_idx in 1:(length(chunks)-1)
+                chunk = chunks[chunk_idx]
+                copy_chunk!(out, chunk, write_offset)
+                write_offset += chunk_len(chunk)
+            end
+            copy_chunk!(out, chunks[end], write_offset, offset-1)
+        end
+
+        # Collapse t_chunks
+        ts = Vector{T}(undef, chunks_len(ilsc.t_chunks, ilsc.t_offset))
+        collapse_chunks!(ts, ilsc.t_chunks, ilsc.t_offset)
+
+        # Collapse u_chunks
+        us = Vector{Matrix{S}}(undef, length(ilsc.u_chunks))
+        for u_idx in 1:length(ilsc.u_chunks)
+            us[u_idx] = Matrix{S}(undef, size(ilsc.u_chunks[u_idx][1],1), chunks_len(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx]))
+            collapse_chunks!(us[u_idx], ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])
+        end
+
+        time_mask = BitMatrix(undef, size(ilsc.time_masks[1], 1), chunks_len(ilsc.time_masks, ilsc.t_offset))
+        collapse_chunks!(time_mask, ilsc.time_masks, ilsc.t_offset)
     end
 
     # Sanity-check lengths
```
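
As a standalone illustration of the compaction arithmetic above (plain Julia, not package code): every full chunk contributes its entire length, while the last chunk contributes `offset - 1` entries, because the offset always points at the next free slot.

```julia
chunk_size = 512
chunks = [rand(chunk_size) for _ in 1:3]
t_offset = 101                     # next write position within the last chunk

total = (length(chunks) - 1) * chunk_size + (t_offset - 1)   # 2*512 + 100 = 1124
ts = Vector{Float64}(undef, total)
let write_offset = 0
    for chunk in chunks[1:end-1]   # copy the full chunks...
        ts[write_offset+1:write_offset+chunk_size] .= chunk
        write_offset += chunk_size
    end
    # ...then the used prefix of the last, partially-filled chunk.
    ts[write_offset+1:end] .= chunks[end][1:t_offset-1]
end
@assert length(ts) == 1124
```
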
```diff
@@ -249,7 +405,24 @@ function finish!(ils::IndependentlyLinearizedSolution{T, S}, return_code) where
         throw(ArgumentError("Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens))"))
     end
 
-    # Update our struct, release the `ilsc`
+    # Update our struct, release the `ilsc` and its caches
+    for t_chunk in ilsc.t_chunks
+        release!(ilsc.cache.t_chunks, t_chunk)
+    end
+    @assert is_fully_released(ilsc.cache.t_chunks)
+    for u_idx in 1:length(ilsc.u_chunks)
+        for u_chunk in ilsc.u_chunks[u_idx]
+            release!(ilsc.cache.u_chunks, u_chunk)
+        end
+    end
+    @assert is_fully_released(ilsc.cache.u_chunks)
+    for time_mask in ilsc.time_masks
+        release!(ilsc.cache.time_masks, time_mask)
+    end
+    @assert is_fully_released(ilsc.cache.time_masks)
+    if ils.ilsc_cache_pool !== nothing
+        release!(ils.ilsc_cache_pool, ilsc.cache)
+    end
     ils.ilsc = nothing
     ils.ts = ts
     ils.us = us
```
