@@ -2,31 +2,144 @@ using SciMLBase
2
2
3
3
export IndependentlyLinearizedSolution
4
4
5
+
6
+ """
7
+ CachePool(T, alloc; thread_safe = true)
8
+
9
+ Simple memory-reusing cache that allows us to grow a cache and keep
10
+ re-using those pieces of memory (in our case, typically `u` vectors)
11
+ until the solve is finished. By default, this datastructure is made
12
+ to be thread-safe by locking on every acquire and release, but it
13
+ can be made thread-unsafe (and correspondingly faster) by passing
14
+ `thread_safe = false` to the constructor.
15
+
16
+ While manual usage with `acquire!()` and `release!()` is possible,
17
+ most users will want to use `@with_cache`, which provides lexically-
18
+ scoped `acquire!()` and `release!()` usage automatically. Example:
19
+
20
+ ```julia
21
+ us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false)
22
+ @with_cache us u_prev begin
23
+ @with_cache us u_next begin
24
+ # perform tasks with these two `u` vectors
25
+ end
26
+ end
27
+ ```
28
+
29
+ !!! warning "Escaping values"
30
+ You must not use an acquired value after you have released it;
31
+ the memory may be immediately re-used by some other consumer of
32
+ your cache pool. Do not allow the acquired value to escape
33
+ outside of the `@with_cache` block, or past a `release!()`.
34
+ """
35
+ mutable struct CachePool{T, THREAD_SAFE}
36
+ const pool:: Vector{T}
37
+ const alloc:: Function
38
+ lock:: ReentrantLock
39
+ num_allocated:: Int
40
+ num_acquired:: Int
41
+
42
+ function CachePool (T, alloc:: F ; thread_safe:: Bool = true ) where {F}
43
+ return new {T,Val{thread_safe}} (T[], alloc, ReentrantLock (), 0 , 0 )
44
+ end
45
+ end
46
+ const ThreadSafeCachePool{T} = CachePool{T,Val{true }}
47
+ const ThreadUnsafeCachePool{T} = CachePool{T,Val{false }}
48
+
49
+ """
50
+ acquire!(cache::CachePool)
51
+
52
+ Returns a cached element of the cache pool, calling `cache.alloc()` if none
53
+ are available.
54
+ """
55
+ Base. @inline function acquire! (cache:: CachePool{T} , _dummy = nothing ) where {T}
56
+ cache. num_acquired += 1
57
+ if isempty (cache. pool)
58
+ cache. num_allocated += 1
59
+ return cache. alloc ():: T
60
+ end
61
+ return pop! (cache. pool)
62
+ end
63
+
64
+ """
65
+ release!(cache::CachePool, val)
66
+
67
+ Returns the value `val` to the cache pool.
68
+ """
69
+ Base. @inline function release! (cache:: CachePool , val, _dummy = nothing )
70
+ push! (cache. pool, val)
71
+ cache. num_acquired -= 1
72
+ end
73
+
74
+ function is_fully_released (cache:: CachePool , _dummy = nothing )
75
+ return cache. num_acquired == 0
76
+ end
77
+
78
+ # Thread-safe versions just sub out to the other methods, using `_dummy` to force correct dispatch
79
+ acquire! (cache:: ThreadSafeCachePool ) = @lock cache. lock acquire! (cache, nothing )
80
+ release! (cache:: ThreadSafeCachePool , val) = @lock cache. lock release! (cache, val, nothing )
81
+ is_fully_released (cache:: ThreadSafeCachePool ) = @lock cache. lock is_fully_released (cache, nothing )
82
+
83
+ macro with_cache (cache, name, body)
84
+ return quote
85
+ $ (esc (name)) = acquire! ($ (esc (cache)))
86
+ try
87
+ $ (esc (body))
88
+ finally
89
+ release! ($ (esc (cache)), $ (esc (name)))
90
+ end
91
+ end
92
+ end
93
+
94
+
95
+ struct IndependentlyLinearizedSolutionChunksCache{T,S}
96
+ t_chunks:: ThreadUnsafeCachePool{Vector{T}}
97
+ u_chunks:: ThreadUnsafeCachePool{Matrix{S}}
98
+ time_masks:: ThreadUnsafeCachePool{BitMatrix}
99
+
100
+ function IndependentlyLinearizedSolutionChunksCache {T,S} (num_us:: Int , num_derivatives:: Int , chunk_size:: Int ) where {T,S}
101
+ t_chunks_alloc = () -> Vector {T} (undef, chunk_size)
102
+ u_chunks_alloc = () -> Matrix {S} (undef, num_derivatives+ 1 , chunk_size)
103
+ time_masks_alloc = () -> BitMatrix (undef, num_us, chunk_size)
104
+ return new (
105
+ CachePool (Vector{T}, t_chunks_alloc; thread_safe= false ),
106
+ CachePool (Matrix{S}, u_chunks_alloc; thread_safe= false ),
107
+ CachePool (BitMatrix, time_masks_alloc; thread_safe= false ),
108
+ )
109
+ end
110
+ end
111
+
5
112
"""
6
113
IndependentlyLinearizedSolutionChunks
7
114
8
115
When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
9
116
we use this indermediate structure to reduce allocations and collect the unknown number of timesteps
10
117
that the solve will generate.
11
118
"""
12
- mutable struct IndependentlyLinearizedSolutionChunks{T, S}
119
+ mutable struct IndependentlyLinearizedSolutionChunks{T, S, N }
13
120
t_chunks:: Vector{Vector{T}}
14
121
u_chunks:: Vector{Vector{Matrix{S}}}
15
122
time_masks:: Vector{BitMatrix}
16
123
124
+ # Temporary array that gets used by `get_chunks`
125
+ last_chunks:: Vector{Matrix{S}}
126
+
17
127
# Index of next write into the last chunk
18
128
u_offsets:: Vector{Int}
19
129
t_offset:: Int
20
130
21
- function IndependentlyLinearizedSolutionChunks {T, S} (
22
- num_us:: Int , num_derivatives:: Int = 0 ,
23
- chunk_size:: Int = 100 ) where {T, S}
24
- return new ([Vector {T} (undef, chunk_size)],
25
- [[Matrix {S} (undef, num_derivatives + 1 , chunk_size)] for _ in 1 : num_us],
26
- [BitMatrix (undef, num_us, chunk_size)],
27
- [1 for _ in 1 : num_us],
28
- 1
29
- )
131
+ cache:: IndependentlyLinearizedSolutionChunksCache
132
+
133
+ function IndependentlyLinearizedSolutionChunks {T, S} (num_us:: Int , num_derivatives:: Int = 0 ,
134
+ chunk_size:: Int = 512 ,
135
+ cache:: IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache {T,S} (num_us, num_derivatives, chunk_size)) where {T, S}
136
+ t_chunks = [acquire! (cache. t_chunks)]
137
+ u_chunks = [[acquire! (cache. u_chunks)] for _ in 1 : num_us]
138
+ time_masks = [acquire! (cache. time_masks)]
139
+ last_chunks = [u_chunks[u_idx][1 ] for u_idx in 1 : num_us]
140
+ u_offsets = [1 for _ in 1 : num_us]
141
+ t_offset = 1
142
+ return new {T,S,num_derivatives} (t_chunks, u_chunks, time_masks, last_chunks, u_offsets, t_offset, cache)
30
143
end
31
144
end
32
145
@@ -45,14 +158,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks)
45
158
end
46
159
return length (ilsc. u_chunks)
47
160
end
161
+ num_derivatives (ilsc:: IndependentlyLinearizedSolutionChunks{T,S,N} ) where {T,S,N} = N
48
162
49
- function num_derivatives (ilsc:: IndependentlyLinearizedSolutionChunks )
50
- # If we've been finalized, just return `0` (which means only the primal)
51
- if isempty (ilsc. t_chunks)
52
- return 0
53
- end
54
- return size (first (first (ilsc. u_chunks)), 1 ) - 1
55
- end
56
163
57
164
function Base. isempty (ilsc:: IndependentlyLinearizedSolutionChunks )
58
165
return length (ilsc. t_chunks) == 1 && ilsc. t_offset == 1
@@ -62,24 +169,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T,
62
169
# Check if we need to allocate new `t` chunk
63
170
chunksize = chunk_size (ilsc)
64
171
if ilsc. t_offset > chunksize
65
- push! (ilsc. t_chunks, Vector {T} (undef, chunksize ))
66
- push! (ilsc. time_masks, BitMatrix (undef, length ( ilsc. u_offsets), chunksize ))
172
+ push! (ilsc. t_chunks, acquire! (ilsc . cache . t_chunks ))
173
+ push! (ilsc. time_masks, acquire! ( ilsc. cache . time_masks ))
67
174
ilsc. t_offset = 1
68
175
end
69
176
70
177
# Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
71
178
for (u_idx, u_chunks) in enumerate (ilsc. u_chunks)
72
179
if ilsc. u_offsets[u_idx] > chunksize
73
- push! (u_chunks, Matrix {S} (undef, num_derivatives ( ilsc) + 1 , chunksize ))
180
+ push! (u_chunks, acquire! ( ilsc. cache . u_chunks ))
74
181
ilsc. u_offsets[u_idx] = 1
75
182
end
183
+ ilsc. last_chunks[u_idx] = u_chunks[end ]
76
184
end
77
185
78
186
# return the last chunk for each
79
187
return (
80
188
ilsc. t_chunks[end ],
81
189
ilsc. time_masks[end ],
82
- [u_chunks[ end ] for u_chunks in ilsc. u_chunks]
190
+ ilsc. last_chunks,
83
191
)
84
192
end
85
193
@@ -137,16 +245,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
137
245
ts, time_mask, us = get_chunks (ilsc)
138
246
139
247
# Store into the chunks, gated by `u_mask`
140
- for u_idx in 1 : size (u, 2 )
248
+ @inbounds for u_idx in 1 : size (u, 2 )
141
249
if u_mask[u_idx]
142
250
for deriv_idx in 1 : size (u, 1 )
143
251
us[u_idx][deriv_idx, ilsc. u_offsets[u_idx]] = u[deriv_idx, u_idx]
144
252
end
145
253
ilsc. u_offsets[u_idx] += 1
146
254
end
255
+
256
+ # Update our `time_mask` while we're at it
257
+ time_mask[u_idx, ilsc. t_offset] = u_mask[u_idx]
147
258
end
148
259
ts[ilsc. t_offset] = t
149
- time_mask[:, ilsc. t_offset] .= u_mask
150
260
ilsc. t_offset += 1
151
261
end
152
262
@@ -161,7 +271,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views
161
271
of the state variables at all timepoints, as well as an efficient `sample!()`
162
272
method that can sample at arbitrary timesteps.
163
273
"""
164
- mutable struct IndependentlyLinearizedSolution{T, S}
274
+ mutable struct IndependentlyLinearizedSolution{T, S, N }
165
275
# All timepoints, shared by all `us`
166
276
ts:: Vector{T}
167
277
@@ -173,32 +283,37 @@ mutable struct IndependentlyLinearizedSolution{T, S}
173
283
time_mask:: BitMatrix
174
284
175
285
# Temporary object used during construction, will be set to `nothing` at the end.
176
- ilsc:: Union{Nothing, IndependentlyLinearizedSolutionChunks{T, S}}
286
+ ilsc:: Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S,N}}
287
+ ilsc_cache_pool:: Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}}
177
288
end
178
289
# Helper function to create an ILS wrapped around an in-progress ILSC
179
- function IndependentlyLinearizedSolution (ilsc:: IndependentlyLinearizedSolutionChunks {
180
- T, S}) where {T, S}
181
- ils = IndependentlyLinearizedSolution (
290
+ function IndependentlyLinearizedSolution (ilsc:: IndependentlyLinearizedSolutionChunks{T,S,N} , cache_pool = nothing ) where {T,S,N}
291
+ return IndependentlyLinearizedSolution {T,S,N} (
182
292
T[],
183
293
Matrix{S}[],
184
- BitMatrix (undef, 0 , 0 ),
185
- ilsc
294
+ BitMatrix (undef, 0 ,0 ),
295
+ ilsc,
296
+ cache_pool,
186
297
)
187
- return ils
188
298
end
189
299
# Automatically create an ILS wrapped around an ILSC from a `prob`
190
- function IndependentlyLinearizedSolution (
191
- prob:: SciMLBase.AbstractDEProblem , num_derivatives = 0 )
300
+ function IndependentlyLinearizedSolution (prob:: SciMLBase.AbstractDEProblem , num_derivatives = 0 ;
301
+ cache_pool = nothing ,
302
+ chunk_size:: Int = 512 )
192
303
T = eltype (prob. tspan)
304
+ S = eltype (prob. u0)
193
305
U = isnothing (prob. u0) ? Float64 : eltype (prob. u0)
194
- N = isnothing (prob. u0) ? 0 : length (prob. u0)
195
- chunks = IndependentlyLinearizedSolutionChunks {T, U} (N, num_derivatives)
196
- return IndependentlyLinearizedSolution (chunks)
306
+ num_us = isnothing (prob. u0) ? 0 : length (prob. u0)
307
+ if cache_pool === nothing
308
+ cache = IndependentlyLinearizedSolutionChunksCache {T,S} (num_us, num_derivatives, chunk_size)
309
+ else
310
+ cache = acquire! (cache_pool)
311
+ end
312
+ chunks = IndependentlyLinearizedSolutionChunks {T,U} (num_us, num_derivatives, chunk_size, cache)
313
+ return IndependentlyLinearizedSolution (chunks, cache_pool)
197
314
end
198
315
199
- function num_derivatives (ils:: IndependentlyLinearizedSolution )
200
- ! isempty (ils. us) ? size (first (ils. us), 1 ) : 0
201
- end
316
+ num_derivatives (:: IndependentlyLinearizedSolution{T,S,N} ) where {T,S,N} = N
202
317
num_us (ils:: IndependentlyLinearizedSolution ) = length (ils. us)
203
318
Base. size (ils:: IndependentlyLinearizedSolution ) = size (ils. time_mask)
204
319
Base. length (ils:: IndependentlyLinearizedSolution ) = length (ils. ts)
@@ -226,10 +341,51 @@ function finish!(ils::IndependentlyLinearizedSolution{T, S}, return_code) where
226
341
us = Vector {Matrix{S}} ()
227
342
time_mask = BitMatrix (undef, 0 , 0 )
228
343
else
229
- ts = vcat (trim_chunk (ilsc. t_chunks, ilsc. t_offset)... )
230
- time_mask = hcat (trim_chunk (ilsc. time_masks, ilsc. t_offset)... )
231
- us = [hcat (trim_chunk (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])... )
232
- for u_idx in 1 : length (ilsc. u_chunks)]
344
+ chunk_len (chunk) = size (chunk, ndims (chunk))
345
+ function chunks_len (chunks:: Vector , offset)
346
+ len = 0
347
+ for chunk_idx in 1 : length (chunks)- 1
348
+ len += chunk_len (chunks[chunk_idx])
349
+ end
350
+ return len + offset - 1
351
+ end
352
+
353
+ function copy_chunk! (out:: Vector , in:: Vector , out_offset:: Int , len= chunk_len (in))
354
+ for idx in 1 : len
355
+ out[idx+ out_offset] = in[idx]
356
+ end
357
+ end
358
+ function copy_chunk! (out:: AbstractMatrix , in:: AbstractMatrix , out_offset:: Int , len= chunk_len (in))
359
+ for zdx in 1 : size (in, 1 )
360
+ for idx in 1 : len
361
+ out[zdx, idx+ out_offset] = in[zdx, idx]
362
+ end
363
+ end
364
+ end
365
+
366
+ function collapse_chunks! (out, chunks, offset:: Int )
367
+ write_offset = 0
368
+ for chunk_idx in 1 : (length (chunks)- 1 )
369
+ chunk = chunks[chunk_idx]
370
+ copy_chunk! (out, chunk, write_offset)
371
+ write_offset += chunk_len (chunk)
372
+ end
373
+ copy_chunk! (out, chunks[end ], write_offset, offset- 1 )
374
+ end
375
+
376
+ # Collapse t_chunks
377
+ ts = Vector {T} (undef, chunks_len (ilsc. t_chunks, ilsc. t_offset))
378
+ collapse_chunks! (ts, ilsc. t_chunks, ilsc. t_offset)
379
+
380
+ # Collapse u_chunks
381
+ us = Vector {Matrix{S}} (undef, length (ilsc. u_chunks))
382
+ for u_idx in 1 : length (ilsc. u_chunks)
383
+ us[u_idx] = Matrix {S} (undef, size (ilsc. u_chunks[u_idx][1 ],1 ), chunks_len (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx]))
384
+ collapse_chunks! (us[u_idx], ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])
385
+ end
386
+
387
+ time_mask = BitMatrix (undef, size (ilsc. time_masks[1 ], 1 ), chunks_len (ilsc. time_masks, ilsc. t_offset))
388
+ collapse_chunks! (time_mask, ilsc. time_masks, ilsc. t_offset)
233
389
end
234
390
235
391
# Sanity-check lengths
@@ -249,7 +405,24 @@ function finish!(ils::IndependentlyLinearizedSolution{T, S}, return_code) where
249
405
throw (ArgumentError (" Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens) )" ))
250
406
end
251
407
252
- # Update our struct, release the `ilsc`
408
+ # Update our struct, release the `ilsc` and its caches
409
+ for t_chunk in ilsc. t_chunks
410
+ release! (ilsc. cache. t_chunks, t_chunk)
411
+ end
412
+ @assert is_fully_released (ilsc. cache. t_chunks)
413
+ for u_idx in 1 : length (ilsc. u_chunks)
414
+ for u_chunk in ilsc. u_chunks[u_idx]
415
+ release! (ilsc. cache. u_chunks, u_chunk)
416
+ end
417
+ end
418
+ @assert is_fully_released (ilsc. cache. u_chunks)
419
+ for time_mask in ilsc. time_masks
420
+ release! (ilsc. cache. time_masks, time_mask)
421
+ end
422
+ @assert is_fully_released (ilsc. cache. time_masks)
423
+ if ils. ilsc_cache_pool != = nothing
424
+ release! (ils. ilsc_cache_pool, ilsc. cache)
425
+ end
253
426
ils. ilsc = nothing
254
427
ils. ts = ts
255
428
ils. us = us
0 commit comments