@@ -2,30 +2,144 @@ using SciMLBase
2
2
3
3
export IndependentlyLinearizedSolution
4
4
5
+
6
"""
    CachePool(T, alloc; thread_safe = true)

A small object pool that hands out values of type `T` and takes them back for
later re-use, so that the same pieces of memory (in our case, typically `u`
vectors) can be recycled until the solve is finished.  `alloc` is a
zero-argument callable that is invoked to create a fresh value whenever the
pool has nothing left to hand out.

By default this data structure is thread-safe: every acquire and release takes
a lock.  Pass `thread_safe = false` to the constructor to skip the locking
(and its cost).

Manual use through `acquire!()` and `release!()` is possible, but most users
will prefer `@with_cache`, which pairs the two calls automatically around a
lexical block.  Example:

```julia
us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false)
@with_cache us u_prev begin
    @with_cache us u_next begin
        # perform tasks with these two `u` vectors
    end
end
```

!!! warning "Escaping values"
    A value must not be used once it has been released: the pool may hand the
    same memory to another consumer right away.  Do not let an acquired value
    escape its `@with_cache` block, or outlive its `release!()`.
"""
mutable struct CachePool{T, THREAD_SAFE}
    # Values that are currently available for re-use
    const pool::Vector{T}
    # Zero-argument callable producing a fresh value
    const alloc::Function
    # Taken by the thread-safe `acquire!`/`release!`/`is_fully_released` methods
    lock::ReentrantLock
    # Number of values created through `alloc`
    num_allocated::Int
    # Number of values currently handed out
    num_acquired::Int

    function CachePool(T, alloc::F; thread_safe::Bool = true) where {F}
        # The thread-safety choice is encoded as a `Val` type parameter so
        # that the locking and non-locking methods are selected by dispatch.
        safety_tag = Val{thread_safe}
        available = T[]
        return new{T,safety_tag}(available, alloc, ReentrantLock(), 0, 0)
    end
end

# Aliases that pin the thread-safety parameter, used for dispatch and field types
const ThreadSafeCachePool{T} = CachePool{T,Val{true}}
const ThreadUnsafeCachePool{T} = CachePool{T,Val{false}}
48
+
49
"""
    acquire!(cache::CachePool)

Return an element of the cache pool, re-using a previously released one when
one is available and calling `cache.alloc()` otherwise.

The bookkeeping counters are only updated after a value has actually been
obtained, so an exception thrown by `cache.alloc()` leaves the pool's
accounting untouched (and `is_fully_released` stays truthful).

The trailing `_dummy` argument exists only so that the thread-safe methods
can forward to this implementation without dispatching back to themselves.
"""
Base.@inline function acquire!(cache::CachePool{T}, _dummy = nothing) where {T}
    if isempty(cache.pool)
        # Allocate before touching the counters: if `alloc` throws, nothing
        # was handed out and nothing must be counted.
        val = cache.alloc()::T
        cache.num_allocated += 1
    else
        val = pop!(cache.pool)
    end
    cache.num_acquired += 1
    return val
end
63
+
64
"""
    release!(cache::CachePool, val)

Hand `val` back to the cache pool so that a later `acquire!()` can re-use it.
Evaluates to the number of values that are still acquired afterwards.
"""
Base.@inline function release!(cache::CachePool, val, _dummy = nothing)
    push!(cache.pool, val)
    outstanding = cache.num_acquired - 1
    cache.num_acquired = outstanding
    return outstanding
end
73
+
74
# Report whether every value handed out by `acquire!()` has been given back,
# i.e. whether the acquired-count is back at zero.
is_fully_released(cache::CachePool, _dummy = nothing) = iszero(cache.num_acquired)
77
+
78
# Thread-safe pools hold `cache.lock` for the duration of the operation and
# then defer to the generic implementations.  Passing the extra `nothing`
# (the `_dummy` argument) forces dispatch onto the generic method instead of
# recursing into these ones.
function acquire!(cache::ThreadSafeCachePool)
    return lock(cache.lock) do
        acquire!(cache, nothing)
    end
end

function release!(cache::ThreadSafeCachePool, val)
    return lock(cache.lock) do
        release!(cache, val, nothing)
    end
end

function is_fully_released(cache::ThreadSafeCachePool)
    return lock(cache.lock) do
        is_fully_released(cache, nothing)
    end
end
82
+
83
"""
    @with_cache cache name body

Acquire a value from `cache`, bind it to `name` in the caller's scope, run
`body`, and release the value back to `cache` when `body` exits, whether it
exits normally or by throwing.
"""
macro with_cache(cache, name, body)
    # Escape the user-supplied pieces once so they resolve in the caller's scope
    pool_expr = esc(cache)
    binding = esc(name)
    block = esc(body)
    return quote
        $binding = acquire!($pool_expr)
        try
            $block
        finally
            release!($pool_expr, $binding)
        end
    end
end
93
+
94
+
95
"""
    IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)

Bundle of three thread-unsafe `CachePool`s that supply the chunk buffers:
`Vector{T}` of length `chunk_size` for timepoints, `Matrix{S}` of size
`(num_derivatives + 1, chunk_size)` for state values, and `BitMatrix` of size
`(num_us, chunk_size)` for time masks.
"""
struct IndependentlyLinearizedSolutionChunksCache{T,S}
    t_chunks::ThreadUnsafeCachePool{Vector{T}}
    u_chunks::ThreadUnsafeCachePool{Matrix{S}}
    time_masks::ThreadUnsafeCachePool{BitMatrix}

    function IndependentlyLinearizedSolutionChunksCache{T,S}(num_us::Int, num_derivatives::Int, chunk_size::Int) where {T,S}
        # One row for the primal plus one per derivative
        u_rows = num_derivatives + 1
        t_pool = CachePool(Vector{T}, () -> Vector{T}(undef, chunk_size); thread_safe = false)
        u_pool = CachePool(Matrix{S}, () -> Matrix{S}(undef, u_rows, chunk_size); thread_safe = false)
        mask_pool = CachePool(BitMatrix, () -> BitMatrix(undef, num_us, chunk_size); thread_safe = false)
        return new{T,S}(t_pool, u_pool, mask_pool)
    end
end
111
+
5
112
"""
6
113
IndependentlyLinearizedSolutionChunks
7
114
8
115
When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
9
116
 we use this intermediate structure to reduce allocations and collect the unknown number of timesteps
10
117
that the solve will generate.
11
118
"""
12
mutable struct IndependentlyLinearizedSolutionChunks{T, S, N}
    # Chunked storage; the type parameter `N` records `num_derivatives`
    t_chunks::Vector{Vector{T}}
    u_chunks::Vector{Vector{Matrix{S}}}
    time_masks::Vector{BitMatrix}

    # Temporary array that gets used by `get_chunks`
    last_chunks::Vector{Matrix{S}}

    # Index of next write into the last chunk
    u_offsets::Vector{Int}
    t_offset::Int

    # Pools that the chunks above are acquired from
    cache::IndependentlyLinearizedSolutionChunksCache

    function IndependentlyLinearizedSolutionChunks{T, S}(num_us::Int, num_derivatives::Int = 0,
                                                         chunk_size::Int = 512,
                                                         cache::IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)) where {T, S}
        # Start out with exactly one chunk of each kind, taken from the cache
        first_ts = acquire!(cache.t_chunks)
        per_u_chunks = [[acquire!(cache.u_chunks)] for _ in 1:num_us]
        first_mask = acquire!(cache.time_masks)

        # The newest chunk of every `u` is, for now, its only chunk
        newest_chunks = [first(chunks) for chunks in per_u_chunks]

        # Every write position starts at the first column
        offsets = [1 for _ in 1:num_us]

        return new{T,S,num_derivatives}(
            [first_ts],
            per_u_chunks,
            [first_mask],
            newest_chunks,
            offsets,
            1,
            cache,
        )
    end
end
31
145
@@ -44,14 +158,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks)
44
158
end
45
159
return length (ilsc. u_chunks)
46
160
end
161
# The derivative count is carried as the third type parameter, so it is
# read straight from the type rather than from the stored chunks.
function num_derivatives(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}) where {T,S,N}
    return N
end
55
163
56
164
function Base. isempty (ilsc:: IndependentlyLinearizedSolutionChunks )
57
165
return length (ilsc. t_chunks) == 1 && ilsc. t_offset == 1
@@ -61,24 +169,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T,
61
169
# Check if we need to allocate new `t` chunk
62
170
chunksize = chunk_size (ilsc)
63
171
if ilsc. t_offset > chunksize
64
- push! (ilsc. t_chunks, Vector {T} (undef, chunksize ))
65
- push! (ilsc. time_masks, BitMatrix (undef, length ( ilsc. u_offsets), chunksize ))
172
+ push! (ilsc. t_chunks, acquire! (ilsc . cache . t_chunks ))
173
+ push! (ilsc. time_masks, acquire! ( ilsc. cache . time_masks ))
66
174
ilsc. t_offset = 1
67
175
end
68
176
69
177
# Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
70
178
for (u_idx, u_chunks) in enumerate (ilsc. u_chunks)
71
179
if ilsc. u_offsets[u_idx] > chunksize
72
- push! (u_chunks, Matrix {S} (undef, num_derivatives ( ilsc) + 1 , chunksize ))
180
+ push! (u_chunks, acquire! ( ilsc. cache . u_chunks ))
73
181
ilsc. u_offsets[u_idx] = 1
74
182
end
183
+ ilsc. last_chunks[u_idx] = u_chunks[end ]
75
184
end
76
185
77
186
# return the last chunk for each
78
187
return (
79
188
ilsc. t_chunks[end ],
80
189
ilsc. time_masks[end ],
81
- [u_chunks[ end ] for u_chunks in ilsc. u_chunks] ,
190
+ ilsc. last_chunks ,
82
191
)
83
192
end
84
193
@@ -135,16 +244,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
135
244
ts, time_mask, us = get_chunks (ilsc)
136
245
137
246
# Store into the chunks, gated by `u_mask`
138
- for u_idx in 1 : size (u, 2 )
247
+ @inbounds for u_idx in 1 : size (u, 2 )
139
248
if u_mask[u_idx]
140
249
for deriv_idx in 1 : size (u, 1 )
141
250
us[u_idx][deriv_idx, ilsc. u_offsets[u_idx]] = u[deriv_idx, u_idx]
142
251
end
143
252
ilsc. u_offsets[u_idx] += 1
144
253
end
254
+
255
+ # Update our `time_mask` while we're at it
256
+ time_mask[u_idx, ilsc. t_offset] = u_mask[u_idx]
145
257
end
146
258
ts[ilsc. t_offset] = t
147
- time_mask[:, ilsc. t_offset] .= u_mask
148
259
ilsc. t_offset += 1
149
260
end
150
261
@@ -161,7 +272,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views
161
272
of the state variables at all timepoints, as well as an efficient `sample!()`
162
273
method that can sample at arbitrary timesteps.
163
274
"""
164
- mutable struct IndependentlyLinearizedSolution{T, S}
275
+ mutable struct IndependentlyLinearizedSolution{T, S, N }
165
276
# All timepoints, shared by all `us`
166
277
ts:: Vector{T}
167
278
@@ -173,33 +284,42 @@ mutable struct IndependentlyLinearizedSolution{T, S}
173
284
time_mask:: BitMatrix
174
285
175
286
# Temporary object used during construction, will be set to `nothing` at the end.
176
- ilsc:: Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}}
287
+ ilsc:: Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S,N}}
288
+ ilsc_cache_pool:: Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}}
177
289
end
178
290
# Wrap an in-progress ILSC in an ILS whose `ts`, `us` and `time_mask` are
# still empty.  `cache_pool` is stored alongside the ILSC and may be `nothing`.
function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}, cache_pool = nothing) where {T,S,N}
    empty_ts = T[]
    empty_us = Matrix{S}[]
    empty_mask = BitMatrix(undef, 0, 0)
    return IndependentlyLinearizedSolution{T,S,N}(empty_ts, empty_us, empty_mask, ilsc, cache_pool)
end
188
300
# Automatically create an ILS wrapped around an ILSC from a `prob`.
#
# - `num_derivatives`: number of derivatives stored in addition to the primal.
# - `cache_pool`: optional pool to acquire the chunks cache from; when it is
#   `nothing`, a fresh cache is built from `chunk_size`.
# - `chunk_size`: number of timepoints per chunk for a freshly-built cache.
#
# The state element type falls back to `Float64` (and the number of states to
# zero) when `prob.u0` is `nothing`.  The same element type is used for both
# the cache and the chunks so that their buffers always agree.
function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem, num_derivatives = 0;
                                         cache_pool = nothing,
                                         chunk_size::Int = 512)
    T = eltype(prob.tspan)
    u0 = prob.u0
    S = isnothing(u0) ? Float64 : eltype(u0)
    num_us = isnothing(u0) ? 0 : length(u0)
    if cache_pool === nothing
        cache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)
    else
        cache = acquire!(cache_pool)
    end
    chunks = IndependentlyLinearizedSolutionChunks{T,S}(num_us, num_derivatives, chunk_size, cache)
    return IndependentlyLinearizedSolution(chunks, cache_pool)
end
196
316
197
# The derivative count is carried as the third type parameter of the solution
function num_derivatives(::IndependentlyLinearizedSolution{T,S,N}) where {T,S,N}
    return N
end

# One entry of `us` per state variable
function num_us(ils::IndependentlyLinearizedSolution)
    return length(ils.us)
end

# The overall size is that of the time mask
function Base.size(ils::IndependentlyLinearizedSolution)
    return size(ils.time_mask)
end

# The length is the number of stored timepoints
function Base.length(ils::IndependentlyLinearizedSolution)
    return length(ils.ts)
end
201
321
202
- function finish! (ils:: IndependentlyLinearizedSolution )
322
+ function finish! (ils:: IndependentlyLinearizedSolution{T,S} ) where {T,S}
203
323
function trim_chunk (chunks:: Vector , offset)
204
324
chunks = [chunk for chunk in chunks]
205
325
if eltype (chunks) <: AbstractVector
@@ -216,10 +336,52 @@ function finish!(ils::IndependentlyLinearizedSolution)
216
336
end
217
337
218
338
ilsc = ils. ilsc:: IndependentlyLinearizedSolutionChunks
219
- ts = vcat (trim_chunk (ilsc. t_chunks, ilsc. t_offset)... )
220
- time_mask = hcat (trim_chunk (ilsc. time_masks, ilsc. t_offset)... )
221
- us = [hcat (trim_chunk (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])... )
222
- for u_idx in 1 : length (ilsc. u_chunks)]
339
+
340
+ chunk_len (chunk) = size (chunk, ndims (chunk))
341
+ function chunks_len (chunks:: Vector , offset)
342
+ len = 0
343
+ for chunk_idx in 1 : length (chunks)- 1
344
+ len += chunk_len (chunks[chunk_idx])
345
+ end
346
+ return len + offset - 1
347
+ end
348
+
349
+ function copy_chunk! (out:: Vector , in:: Vector , out_offset:: Int , len= chunk_len (in))
350
+ for idx in 1 : len
351
+ out[idx+ out_offset] = in[idx]
352
+ end
353
+ end
354
+ function copy_chunk! (out:: AbstractMatrix , in:: AbstractMatrix , out_offset:: Int , len= chunk_len (in))
355
+ for zdx in 1 : size (in, 1 )
356
+ for idx in 1 : len
357
+ out[zdx, idx+ out_offset] = in[zdx, idx]
358
+ end
359
+ end
360
+ end
361
+
362
+ function collapse_chunks! (out, chunks, offset:: Int )
363
+ write_offset = 0
364
+ for chunk_idx in 1 : (length (chunks)- 1 )
365
+ chunk = chunks[chunk_idx]
366
+ copy_chunk! (out, chunk, write_offset)
367
+ write_offset += chunk_len (chunk)
368
+ end
369
+ copy_chunk! (out, chunks[end ], write_offset, offset- 1 )
370
+ end
371
+
372
+ # Collapse t_chunks
373
+ ts = Vector {T} (undef, chunks_len (ilsc. t_chunks, ilsc. t_offset))
374
+ collapse_chunks! (ts, ilsc. t_chunks, ilsc. t_offset)
375
+
376
+ # Collapse u_chunks
377
+ us = Vector {Matrix{S}} (undef, length (ilsc. u_chunks))
378
+ for u_idx in 1 : length (ilsc. u_chunks)
379
+ us[u_idx] = Matrix {S} (undef, size (ilsc. u_chunks[u_idx][1 ],1 ), chunks_len (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx]))
380
+ collapse_chunks! (us[u_idx], ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])
381
+ end
382
+
383
+ time_mask = BitMatrix (undef, size (ilsc. time_masks[1 ], 1 ), chunks_len (ilsc. time_masks, ilsc. t_offset))
384
+ collapse_chunks! (time_mask, ilsc. time_masks, ilsc. t_offset)
223
385
224
386
# Sanity-check lengths
225
387
if length (ts) != size (time_mask, 2 )
@@ -238,7 +400,24 @@ function finish!(ils::IndependentlyLinearizedSolution)
238
400
throw (ArgumentError (" Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens) )" ))
239
401
end
240
402
241
- # Update our struct, release the `ilsc`
403
+ # Update our struct, release the `ilsc` and its caches
404
+ for t_chunk in ilsc. t_chunks
405
+ release! (ilsc. cache. t_chunks, t_chunk)
406
+ end
407
+ @assert is_fully_released (ilsc. cache. t_chunks)
408
+ for u_idx in 1 : length (ilsc. u_chunks)
409
+ for u_chunk in ilsc. u_chunks[u_idx]
410
+ release! (ilsc. cache. u_chunks, u_chunk)
411
+ end
412
+ end
413
+ @assert is_fully_released (ilsc. cache. u_chunks)
414
+ for time_mask in ilsc. time_masks
415
+ release! (ilsc. cache. time_masks, time_mask)
416
+ end
417
+ @assert is_fully_released (ilsc. cache. time_masks)
418
+ if ils. ilsc_cache_pool != = nothing
419
+ release! (ils. ilsc_cache_pool, ilsc. cache)
420
+ end
242
421
ils. ilsc = nothing
243
422
ils. ts = ts
244
423
ils. us = us
0 commit comments