Skip to content

Commit 11cda8c

Browse files
committed
Improve caching and dispatch of LinearizingSavingCallback
This adds a new type, `LinearizingSavingCallbackCache`, and some sub-types to allow for efficient re-use of memory as the callback executes over the course of a solve, as well as re-use of that memory in future solves when operating on a large ensemble simulation. The top-level `LinearizingSavingCallbackCache` creates thread-safe cache pool objects that are then used to acquire thread-unsafe cache pool objects to be used within a single solve. Those thread-unsafe cache pool objects can then be released and acquired anew by the next solve. The thread-unsafe pool objects allow for acquisition of pieces of memory such as temporary `u` vectors (the recursive nature of the `LinearizingSavingCallback` means that we must allocate unknown numbers of temporary `u` vectors) and chunks of `u` blocks that are then compacted into a single large matrix in the finalize method of the callback. All these pieces of memory are stored within that set of thread-unsafe caches, and these are released back to the top-level thread-safe cache pool, for the next solve to acquire and make use of those pieces of memory in the cache pool. Using these techniques, the solve time of a large ensemble simulation with low per-simulation computation has been reduced dramatically. The simulation solves a Butterworth 3rd-order filter circuit over a certain timespan, swept across different stimulus frequencies and circuit parameters. The parameter sweep results in a 13500-element ensemble simulation that, when run with 8 threads on an M1 Pro, takes: ``` 48.364827 seconds (625.86 M allocations: 19.472 GiB, 41.81% gc time, 0.17% compilation time) ``` Now, after these caching optimizations, we solve the same ensemble in: ``` 13.208123 seconds (166.76 M allocations: 7.621 GiB, 22.21% gc time, 0.61% compilation time) ``` As a side note, the size requirement of the raw linearized solution data itself is `1.04 GB`.
In general, we expect to allocate somewhere between 2-3x the final output data to account for temporaries and inefficient sharing, so while there is still some more work to be done, this gets us significantly closer to minimal overhead.
1 parent 08933b5 commit 11cda8c

File tree

4 files changed

+456
-107
lines changed

4 files changed

+456
-107
lines changed

src/independentlylinearizedutils.jl

Lines changed: 201 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,130 @@ using SciMLBase
22

33
export IndependentlyLinearizedSolution
44

5+
6+
"""
7+
CachePool(T, alloc; thread_safe = true)
8+
9+
Simple memory-reusing cache that allows us to grow a cache and keep
10+
re-using those pieces of memory (in our case, typically `u` vectors)
11+
until the solve is finished. By default, this datastructure is made
12+
to be thread-safe by locking on every acquire and release, but it
13+
can be made thread-unsafe (and correspondingly faster) by passing
14+
`thread_safe = false` to the constructor.
15+
16+
While manual usage with `acquire!()` and `release!()` is possible,
17+
most users will want to use `@with_cache`, which provides lexically-
18+
scoped `acquire!()` and `release!()` usage automatically. Example:
19+
20+
```julia
21+
us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false)
22+
@with_cache us u_prev begin
23+
@with_cache us u_next begin
24+
# perform tasks with these two `u` vectors
25+
end
26+
end
27+
```
28+
"""
29+
mutable struct CachePool{T, THREAD_SAFE}
    # Idle elements that are ready to be handed out again by `acquire!()`
    pool::Vector{T}
    # Zero-argument callable used to build a new element when `pool` is empty
    alloc::Function
    # Lock taken by the thread-safe `acquire!()`/`release!()` methods
    lock::ReentrantLock
    # Number of times `alloc` has been invoked through `acquire!()`
    num_alloced::Int

    # `thread_safe` is lifted into the second type parameter (`Val{true}` or
    # `Val{false}`) so that `acquire!()`/`release!()` can dispatch on it.
    function CachePool(T, alloc::F; thread_safe::Bool = true) where {F}
        idle = T[]
        guard = ReentrantLock()
        return new{T, Val{thread_safe}}(idle, alloc, guard, 0)
    end
end
39+
const ThreadSafeCachePool{T} = CachePool{T,Val{true}}
40+
const ThreadUnsafeCachePool{T} = CachePool{T,Val{false}}
41+
42+
"""
43+
acquire!(cache::CachePool)
44+
45+
Returns a cached element of the cache pool, calling `cache.alloc()` if none
46+
are available.
47+
"""
48+
Base.@inline function acquire!(cache::CachePool{T}, _dummy = nothing) where {T}
    # Fast path: re-use an idle element if the pool currently holds one.
    if !isempty(cache.pool)
        return pop!(cache.pool)
    end

    # Pool is exhausted: record the allocation, then build a fresh element.
    # The `::T` assertion checks that `alloc` returned an element of type `T`.
    cache.num_alloced += 1
    return cache.alloc()::T
end
55+
56+
"""
57+
release!(cache::CachePool, val)
58+
59+
Returns the value `val` to the cache pool.
60+
"""
61+
Base.@inline function release!(cache::CachePool, val, _dummy = nothing)
    # Put `val` back on the idle list so that a later `acquire!()` can re-use it.
    idle = cache.pool
    return push!(idle, val)
end
64+
65+
# Thread-safe versions just sub out to the other methods, using `_dummy` to force correct dispatch
66+
function acquire!(cache::ThreadSafeCachePool)
    # Hold the pool's lock while delegating to the two-argument (unlocked) method.
    return lock(cache.lock) do
        acquire!(cache, nothing)
    end
end
function release!(cache::ThreadSafeCachePool, val)
    # Hold the pool's lock while delegating to the three-argument (unlocked) method.
    return lock(cache.lock) do
        release!(cache, val, nothing)
    end
end
68+
69+
"""
    @with_cache cache name body

Acquire an element from `cache` with `acquire!()`, bind it to `name` in the
caller's scope, evaluate `body`, and hand the element back with `release!()`
when `body` exits, whether it returns normally or throws.

The `cache` expression is evaluated exactly once, and the element that was
acquired is the one that is released, even if `body` rebinds `name`.
"""
macro with_cache(cache, name, body)
    return quote
        # These two locals are renamed by macro hygiene, so they cannot clash
        # with (or be rebound by) anything in the caller's `body`.
        local pool = $(esc(cache))
        local acquired = acquire!(pool)
        $(esc(name)) = acquired
        try
            $(esc(body))
        finally
            release!(pool, acquired)
        end
    end
end
79+
80+
81+
struct IndependentlyLinearizedSolutionChunksCache{T,S}
    # One pool per kind of chunk storage; all are created with `thread_safe=false`.
    t_chunks::ThreadUnsafeCachePool{Vector{T}}
    u_chunks::ThreadUnsafeCachePool{Matrix{S}}
    time_masks::ThreadUnsafeCachePool{BitMatrix}

    function IndependentlyLinearizedSolutionChunksCache{T,S}(num_us::Int, num_derivatives::Int, chunk_size::Int) where {T,S}
        # Every pool allocates uninitialized storage on demand:
        #   - `t` chunks:    `chunk_size` timepoints
        #   - `u` chunks:    `(num_derivatives + 1) x chunk_size` values
        #   - time masks:    `num_us x chunk_size` bits
        t_pool = CachePool(Vector{T}, () -> Vector{T}(undef, chunk_size); thread_safe=false)
        u_pool = CachePool(Matrix{S}, () -> Matrix{S}(undef, num_derivatives+1, chunk_size); thread_safe=false)
        mask_pool = CachePool(BitMatrix, () -> BitMatrix(undef, num_us, chunk_size); thread_safe=false)
        return new{T,S}(t_pool, u_pool, mask_pool)
    end
end
97+
598
"""
699
IndependentlyLinearizedSolutionChunks
7100
8101
When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
9102
we use this intermediate structure to reduce allocations and collect the unknown number of timesteps
10103
that the solve will generate.
11104
"""
12-
mutable struct IndependentlyLinearizedSolutionChunks{T, S, N}
    # Chunked storage for timepoints, per-`u` values and the time mask.
    # New chunks are appended as the current ones fill up.
    t_chunks::Vector{Vector{T}}
    u_chunks::Vector{Vector{Matrix{S}}}
    time_masks::Vector{BitMatrix}

    # Temporary array that gets used by `get_chunks`
    last_chunks::Vector{Matrix{S}}

    # Index of next write into the last chunk
    u_offsets::Vector{Int}
    t_offset::Int

    # Pools that every chunk above is acquired from.  The field is concretely
    # parameterized so that `ilsc.cache.<pool>` accesses are type-stable.
    cache::IndependentlyLinearizedSolutionChunksCache{T,S}

    # `num_derivatives` is stored as the type parameter `N`.  If no `cache` is
    # passed, a private one is created for this object.
    function IndependentlyLinearizedSolutionChunks{T, S}(num_us::Int, num_derivatives::Int = 0,
                                                         chunk_size::Int = 512,
                                                         cache::IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)) where {T, S}
        t_chunks = Vector{T}[acquire!(cache.t_chunks)]
        # Typed comprehensions keep the element types concrete even when `num_us == 0`.
        u_chunks = Vector{Matrix{S}}[[acquire!(cache.u_chunks)] for _ in 1:num_us]
        time_masks = BitMatrix[acquire!(cache.time_masks)]
        last_chunks = Matrix{S}[u_chunks[u_idx][1] for u_idx in 1:num_us]
        u_offsets = ones(Int, num_us)
        t_offset = 1
        return new{T,S,num_derivatives}(t_chunks, u_chunks, time_masks, last_chunks, u_offsets, t_offset, cache)
    end
end
31131

@@ -44,14 +144,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks)
44144
end
45145
return length(ilsc.u_chunks)
46146
end
147+
# The derivative count is carried in the third type parameter `N`.
function num_derivatives(::IndependentlyLinearizedSolutionChunks{T,S,N}) where {T,S,N}
    return N
end
47148

48-
function num_derivatives(ilsc::IndependentlyLinearizedSolutionChunks)
49-
# If we've been finalized, just return `0` (which means only the primal)
50-
if isempty(ilsc.t_chunks)
51-
return 0
52-
end
53-
return size(first(first(ilsc.u_chunks)), 1) - 1
54-
end
55149

56150
function Base.isempty(ilsc::IndependentlyLinearizedSolutionChunks)
57151
return length(ilsc.t_chunks) == 1 && ilsc.t_offset == 1
@@ -61,24 +155,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T,
61155
# Check if we need to allocate new `t` chunk
62156
chunksize = chunk_size(ilsc)
63157
if ilsc.t_offset > chunksize
64-
push!(ilsc.t_chunks, Vector{T}(undef, chunksize))
65-
push!(ilsc.time_masks, BitMatrix(undef, length(ilsc.u_offsets), chunksize))
158+
push!(ilsc.t_chunks, acquire!(ilsc.cache.t_chunks))
159+
push!(ilsc.time_masks, acquire!(ilsc.cache.time_masks))
66160
ilsc.t_offset = 1
67161
end
68162

69163
# Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
70164
for (u_idx, u_chunks) in enumerate(ilsc.u_chunks)
71165
if ilsc.u_offsets[u_idx] > chunksize
72-
push!(u_chunks, Matrix{S}(undef, num_derivatives(ilsc)+1, chunksize))
166+
push!(u_chunks, acquire!(ilsc.cache.u_chunks))
73167
ilsc.u_offsets[u_idx] = 1
74168
end
169+
ilsc.last_chunks[u_idx] = u_chunks[end]
75170
end
76171

77172
# return the last chunk for each
78173
return (
79174
ilsc.t_chunks[end],
80175
ilsc.time_masks[end],
81-
[u_chunks[end] for u_chunks in ilsc.u_chunks],
176+
ilsc.last_chunks,
82177
)
83178
end
84179

@@ -135,16 +230,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
135230
ts, time_mask, us = get_chunks(ilsc)
136231

137232
# Store into the chunks, gated by `u_mask`
138-
for u_idx in 1:size(u, 2)
233+
@inbounds for u_idx in 1:size(u, 2)
139234
if u_mask[u_idx]
140235
for deriv_idx in 1:size(u, 1)
141236
us[u_idx][deriv_idx, ilsc.u_offsets[u_idx]] = u[deriv_idx, u_idx]
142237
end
143238
ilsc.u_offsets[u_idx] += 1
144239
end
240+
241+
# Update our `time_mask` while we're at it
242+
time_mask[u_idx, ilsc.t_offset] = u_mask[u_idx]
145243
end
146244
ts[ilsc.t_offset] = t
147-
time_mask[:, ilsc.t_offset] .= u_mask
148245
ilsc.t_offset += 1
149246
end
150247

@@ -161,7 +258,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views
161258
of the state variables at all timepoints, as well as an efficient `sample!()`
162259
method that can sample at arbitrary timesteps.
163260
"""
164-
mutable struct IndependentlyLinearizedSolution{T, S}
261+
mutable struct IndependentlyLinearizedSolution{T, S, N}
165262
# All timepoints, shared by all `us`
166263
ts::Vector{T}
167264

@@ -174,32 +271,44 @@ mutable struct IndependentlyLinearizedSolution{T, S}
174271

175272
# Temporary object used during construction, will be set to `nothing` at the end.
176273
ilsc::Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}}
274+
ilsc_cache_pool::Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}}
177275
end
178276
# Helper function to create an ILS wrapped around an in-progress ILSC
179-
function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}, cache_pool = nothing) where {T,S,N}
    # The solution starts out with empty storage; the in-progress `ilsc` (and the
    # pool its cache came from, if any) are stored alongside it.
    empty_ts = T[]
    empty_us = Matrix{S}[]
    empty_mask = BitMatrix(undef, 0,0)
    return IndependentlyLinearizedSolution{T,S,N}(empty_ts, empty_us, empty_mask, ilsc, cache_pool)
end
188286
# Automatically create an ILS wrapped around an ILSC from a `prob`
189-
function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem, num_derivatives = 0;
                                         cache_pool = nothing,
                                         chunk_size::Int = 512)
    T = eltype(prob.tspan)
    # `prob.u0` may be `nothing`; fall back to `Float64` in that case.  The same
    # element type must be used for the cache pool, the cache and the chunks:
    # `eltype(nothing)` is `Any`, which would not match the chunks' element type.
    S = isnothing(prob.u0) ? Float64 : eltype(prob.u0)
    num_us = isnothing(prob.u0) ? 0 : length(prob.u0)

    # Create a private (thread-safe) pool of chunk caches if none was provided.
    pool = if cache_pool === nothing
        CachePool(
            IndependentlyLinearizedSolutionChunksCache{T,S},
            () -> IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size);
            thread_safe = true,
        )
    else
        cache_pool
    end

    # The acquired cache is stored in the chunks object; the pool is stored in
    # the returned solution.
    cache = acquire!(pool)
    chunks = IndependentlyLinearizedSolutionChunks{T,S}(num_us, num_derivatives, chunk_size, cache)
    return IndependentlyLinearizedSolution(chunks, pool)
end
196305

197-
num_derivatives(ils::IndependentlyLinearizedSolution) = !isempty(ils.us) ? size(first(ils.us), 1) : 0
306+
# The derivative count is carried in the third type parameter `N`.
function num_derivatives(::IndependentlyLinearizedSolution{T,S,N}) where {T,S,N}
    return N
end

# Number of entries in `ils.us`.
function num_us(ils::IndependentlyLinearizedSolution)
    return length(ils.us)
end

# The size of a solution is the size of its time mask.
function Base.size(ils::IndependentlyLinearizedSolution)
    return size(ils.time_mask)
end

# The length of a solution is its number of timepoints.
function Base.length(ils::IndependentlyLinearizedSolution)
    return length(ils.ts)
end
201310

202-
function finish!(ils::IndependentlyLinearizedSolution)
311+
function finish!(ils::IndependentlyLinearizedSolution{T,S}) where {T,S}
203312
function trim_chunk(chunks::Vector, offset)
204313
chunks = [chunk for chunk in chunks]
205314
if eltype(chunks) <: AbstractVector
@@ -216,10 +325,52 @@ function finish!(ils::IndependentlyLinearizedSolution)
216325
end
217326

218327
ilsc = ils.ilsc::IndependentlyLinearizedSolutionChunks
219-
ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
220-
time_mask = hcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
221-
us = [hcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)
222-
for u_idx in 1:length(ilsc.u_chunks)]
328+
329+
# Length of a chunk along its last dimension.
function chunk_len(chunk)
    return size(chunk, ndims(chunk))
end

# Total number of entries stored in `chunks`: every chunk except the last one is
# counted in full, and the last chunk contributes `offset - 1` entries
# (`offset` being the index of the next write into it).
function chunks_len(chunks::Vector, offset)
    full_chunks = view(chunks, 1:(length(chunks) - 1))
    filled = mapreduce(chunk_len, +, full_chunks; init = 0)
    return filled + offset - 1
end
337+
338+
# Copy the first `len` entries of `src` into `out`, starting just after
# position `out_offset` (i.e. `out[out_offset + 1]` receives `src[1]`).
function copy_chunk!(out::Vector, src::Vector, out_offset::Int, len=chunk_len(src))
    for idx in 1:len
        out[idx+out_offset] = src[idx]
    end
    return nothing
end

# Copy the first `len` columns of `src` into `out`, starting just after column
# `out_offset`.  Arrays are column-major, so the row index runs in the inner
# loop to keep accesses contiguous.
function copy_chunk!(out::AbstractMatrix, src::AbstractMatrix, out_offset::Int, len=chunk_len(src))
    for idx in 1:len
        for zdx in axes(src, 1)
            out[zdx, idx+out_offset] = src[zdx, idx]
        end
    end
    return nothing
end
350+
351+
# Concatenate `chunks` into `out`: every chunk but the last is copied in full,
# and only the first `offset - 1` entries of the last chunk are copied.
function collapse_chunks!(out, chunks, offset::Int)
    cursor = 0
    num_full = length(chunks) - 1
    for (chunk_idx, chunk) in enumerate(chunks)
        chunk_idx > num_full && break
        copy_chunk!(out, chunk, cursor)
        cursor += chunk_len(chunk)
    end
    return copy_chunk!(out, chunks[end], cursor, offset-1)
end
360+
361+
# Collapse t_chunks
362+
ts = Vector{T}(undef, chunks_len(ilsc.t_chunks, ilsc.t_offset))
363+
collapse_chunks!(ts, ilsc.t_chunks, ilsc.t_offset)
364+
365+
# Collapse u_chunks
366+
us = Vector{Matrix{S}}(undef, length(ilsc.u_chunks))
367+
for u_idx in 1:length(ilsc.u_chunks)
368+
us[u_idx] = Matrix{S}(undef, size(ilsc.u_chunks[u_idx][1],1), chunks_len(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx]))
369+
collapse_chunks!(us[u_idx], ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])
370+
end
371+
372+
time_mask = BitMatrix(undef, size(ilsc.time_masks[1], 1), chunks_len(ilsc.time_masks, ilsc.t_offset))
373+
collapse_chunks!(time_mask, ilsc.time_masks, ilsc.t_offset)
223374

224375
# Sanity-check lengths
225376
if length(ts) != size(time_mask, 2)
@@ -238,7 +389,21 @@ function finish!(ils::IndependentlyLinearizedSolution)
238389
throw(ArgumentError("Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens))"))
239390
end
240391

241-
# Update our struct, release the `ilsc`
392+
# Update our struct, release the `ilsc` and its caches
393+
for t_chunk in ilsc.t_chunks
394+
release!(ilsc.cache.t_chunks, t_chunk)
395+
end
396+
for u_idx in 1:length(ilsc.u_chunks)
397+
for u_chunk in ilsc.u_chunks[u_idx]
398+
release!(ilsc.cache.u_chunks, u_chunk)
399+
end
400+
end
401+
for time_mask in ilsc.time_masks
402+
release!(ilsc.cache.time_masks, time_mask)
403+
end
404+
if ils.ilsc_cache_pool !== nothing
405+
release!(ils.ilsc_cache_pool, ilsc.cache)
406+
end
242407
ils.ilsc = nothing
243408
ils.ts = ts
244409
ils.us = us

0 commit comments

Comments
 (0)