@@ -2,30 +2,130 @@ using SciMLBase
2
2
3
3
export IndependentlyLinearizedSolution
4
4
5
+
6
+ """
7
+ CachePool(T, alloc; thread_safe = true)
8
+
9
+ Simple memory-reusing cache that allows us to grow a cache and keep
10
+ re-using those pieces of memory (in our case, typically `u` vectors)
11
+ until the solve is finished. By default, this datastructure is made
12
+ to be thread-safe by locking on every acquire and release, but it
13
+ can be made thread-unsafe (and correspondingly faster) by passing
14
+ `thread_safe = false` to the constructor.
15
+
16
+ While manual usage with `acquire!()` and `release!()` is possible,
17
+ most users will want to use `@with_cache`, which provides lexically-
18
+ scoped `acquire!()` and `release!()` usage automatically. Example:
19
+
20
+ ```julia
21
+ us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false)
22
+ @with_cache us u_prev begin
23
+ @with_cache us u_next begin
24
+ # perform tasks with these two `u` vectors
25
+ end
26
+ end
27
+ ```
28
+ """
29
+ mutable struct CachePool{T, THREAD_SAFE}
30
+ pool:: Vector{T}
31
+ alloc:: Function
32
+ lock:: ReentrantLock
33
+ num_alloced:: Int
34
+
35
+ function CachePool (T, alloc:: F ; thread_safe:: Bool = true ) where {F}
36
+ return new {T,Val{thread_safe}} (T[], alloc, ReentrantLock (), 0 )
37
+ end
38
+ end
39
+ const ThreadSafeCachePool{T} = CachePool{T,Val{true }}
40
+ const ThreadUnsafeCachePool{T} = CachePool{T,Val{false }}
41
+
42
+ """
43
+ acquire!(cache::CachePool)
44
+
45
+ Returns a cached element of the cache pool, calling `cache.alloc()` if none
46
+ are available.
47
+ """
48
+ Base. @inline function acquire! (cache:: CachePool{T} , _dummy = nothing ) where {T}
49
+ if isempty (cache. pool)
50
+ cache. num_alloced += 1
51
+ return cache. alloc ():: T
52
+ end
53
+ return pop! (cache. pool)
54
+ end
55
+
56
+ """
57
+ release!(cache::CachePool, val)
58
+
59
+ Returns the value `val` to the cache pool.
60
+ """
61
+ Base. @inline function release! (cache:: CachePool , val, _dummy = nothing )
62
+ push! (cache. pool, val)
63
+ end
64
+
65
+ # Thread-safe versions just sub out to the other methods, using `_dummy` to force correct dispatch
66
+ acquire! (cache:: ThreadSafeCachePool ) = @lock cache. lock acquire! (cache, nothing )
67
+ release! (cache:: ThreadSafeCachePool , val) = @lock cache. lock release! (cache, val, nothing )
68
+
69
+ macro with_cache (cache, name, body)
70
+ return quote
71
+ $ (esc (name)) = acquire! ($ (esc (cache)))
72
+ try
73
+ $ (esc (body))
74
+ finally
75
+ release! ($ (esc (cache)), $ (esc (name)))
76
+ end
77
+ end
78
+ end
79
+
80
+
81
+ struct IndependentlyLinearizedSolutionChunksCache{T,S}
82
+ t_chunks:: ThreadUnsafeCachePool{Vector{T}}
83
+ u_chunks:: ThreadUnsafeCachePool{Matrix{S}}
84
+ time_masks:: ThreadUnsafeCachePool{BitMatrix}
85
+
86
+ function IndependentlyLinearizedSolutionChunksCache {T,S} (num_us:: Int , num_derivatives:: Int , chunk_size:: Int ) where {T,S}
87
+ t_chunks_alloc = () -> Vector {T} (undef, chunk_size)
88
+ u_chunks_alloc = () -> Matrix {S} (undef, num_derivatives+ 1 , chunk_size)
89
+ time_masks_alloc = () -> BitMatrix (undef, num_us, chunk_size)
90
+ return new (
91
+ CachePool (Vector{T}, t_chunks_alloc; thread_safe= false ),
92
+ CachePool (Matrix{S}, u_chunks_alloc; thread_safe= false ),
93
+ CachePool (BitMatrix, time_masks_alloc; thread_safe= false ),
94
+ )
95
+ end
96
+ end
97
+
5
98
"""
6
99
IndependentlyLinearizedSolutionChunks
7
100
8
101
When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
9
102
we use this indermediate structure to reduce allocations and collect the unknown number of timesteps
10
103
that the solve will generate.
11
104
"""
12
- mutable struct IndependentlyLinearizedSolutionChunks{T, S}
105
+ mutable struct IndependentlyLinearizedSolutionChunks{T, S, N }
13
106
t_chunks:: Vector{Vector{T}}
14
107
u_chunks:: Vector{Vector{Matrix{S}}}
15
108
time_masks:: Vector{BitMatrix}
16
109
110
+ # Temporary array that gets used by `get_chunks`
111
+ last_chunks:: Vector{Matrix{S}}
112
+
17
113
# Index of next write into the last chunk
18
114
u_offsets:: Vector{Int}
19
115
t_offset:: Int
20
116
117
+ cache:: IndependentlyLinearizedSolutionChunksCache
118
+
21
119
function IndependentlyLinearizedSolutionChunks {T, S} (num_us:: Int , num_derivatives:: Int = 0 ,
22
- chunk_size:: Int = 100 ) where {T, S}
23
- return new ([Vector {T} (undef, chunk_size)],
24
- [[Matrix {S} (undef, num_derivatives+ 1 , chunk_size)] for _ in 1 : num_us],
25
- [BitMatrix (undef, num_us, chunk_size)],
26
- [1 for _ in 1 : num_us],
27
- 1 ,
28
- )
120
+ chunk_size:: Int = 512 ,
121
+ cache:: IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache {T,S} (num_us, num_derivatives, chunk_size)) where {T, S}
122
+ t_chunks = [acquire! (cache. t_chunks)]
123
+ u_chunks = [[acquire! (cache. u_chunks)] for _ in 1 : num_us]
124
+ time_masks = [acquire! (cache. time_masks)]
125
+ last_chunks = [u_chunks[u_idx][1 ] for u_idx in 1 : num_us]
126
+ u_offsets = [1 for _ in 1 : num_us]
127
+ t_offset = 1
128
+ return new {T,S,num_derivatives} (t_chunks, u_chunks, time_masks, last_chunks, u_offsets, t_offset, cache)
29
129
end
30
130
end
31
131
@@ -44,14 +144,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks)
44
144
end
45
145
return length (ilsc. u_chunks)
46
146
end
147
+ num_derivatives (ilsc:: IndependentlyLinearizedSolutionChunks{T,S,N} ) where {T,S,N} = N
47
148
48
- function num_derivatives (ilsc:: IndependentlyLinearizedSolutionChunks )
49
- # If we've been finalized, just return `0` (which means only the primal)
50
- if isempty (ilsc. t_chunks)
51
- return 0
52
- end
53
- return size (first (first (ilsc. u_chunks)), 1 ) - 1
54
- end
55
149
56
150
function Base. isempty (ilsc:: IndependentlyLinearizedSolutionChunks )
57
151
return length (ilsc. t_chunks) == 1 && ilsc. t_offset == 1
@@ -61,24 +155,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T,
61
155
# Check if we need to allocate new `t` chunk
62
156
chunksize = chunk_size (ilsc)
63
157
if ilsc. t_offset > chunksize
64
- push! (ilsc. t_chunks, Vector {T} (undef, chunksize ))
65
- push! (ilsc. time_masks, BitMatrix (undef, length ( ilsc. u_offsets), chunksize ))
158
+ push! (ilsc. t_chunks, acquire! (ilsc . cache . t_chunks ))
159
+ push! (ilsc. time_masks, acquire! ( ilsc. cache . time_masks ))
66
160
ilsc. t_offset = 1
67
161
end
68
162
69
163
# Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
70
164
for (u_idx, u_chunks) in enumerate (ilsc. u_chunks)
71
165
if ilsc. u_offsets[u_idx] > chunksize
72
- push! (u_chunks, Matrix {S} (undef, num_derivatives ( ilsc) + 1 , chunksize ))
166
+ push! (u_chunks, acquire! ( ilsc. cache . u_chunks ))
73
167
ilsc. u_offsets[u_idx] = 1
74
168
end
169
+ ilsc. last_chunks[u_idx] = u_chunks[end ]
75
170
end
76
171
77
172
# return the last chunk for each
78
173
return (
79
174
ilsc. t_chunks[end ],
80
175
ilsc. time_masks[end ],
81
- [u_chunks[ end ] for u_chunks in ilsc. u_chunks] ,
176
+ ilsc. last_chunks ,
82
177
)
83
178
end
84
179
@@ -135,16 +230,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
135
230
ts, time_mask, us = get_chunks (ilsc)
136
231
137
232
# Store into the chunks, gated by `u_mask`
138
- for u_idx in 1 : size (u, 2 )
233
+ @inbounds for u_idx in 1 : size (u, 2 )
139
234
if u_mask[u_idx]
140
235
for deriv_idx in 1 : size (u, 1 )
141
236
us[u_idx][deriv_idx, ilsc. u_offsets[u_idx]] = u[deriv_idx, u_idx]
142
237
end
143
238
ilsc. u_offsets[u_idx] += 1
144
239
end
240
+
241
+ # Update our `time_mask` while we're at it
242
+ time_mask[u_idx, ilsc. t_offset] = u_mask[u_idx]
145
243
end
146
244
ts[ilsc. t_offset] = t
147
- time_mask[:, ilsc. t_offset] .= u_mask
148
245
ilsc. t_offset += 1
149
246
end
150
247
@@ -161,7 +258,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views
161
258
of the state variables at all timepoints, as well as an efficient `sample!()`
162
259
method that can sample at arbitrary timesteps.
163
260
"""
164
- mutable struct IndependentlyLinearizedSolution{T, S}
261
+ mutable struct IndependentlyLinearizedSolution{T, S, N }
165
262
# All timepoints, shared by all `us`
166
263
ts:: Vector{T}
167
264
@@ -174,32 +271,44 @@ mutable struct IndependentlyLinearizedSolution{T, S}
174
271
175
272
# Temporary object used during construction, will be set to `nothing` at the end.
176
273
ilsc:: Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}}
274
+ ilsc_cache_pool:: Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}}
177
275
end
178
276
# Helper function to create an ILS wrapped around an in-progress ILSC
179
- function IndependentlyLinearizedSolution (ilsc:: IndependentlyLinearizedSolutionChunks{T,S} ) where {T,S}
180
- ils = IndependentlyLinearizedSolution (
277
+ function IndependentlyLinearizedSolution (ilsc:: IndependentlyLinearizedSolutionChunks{T,S,N} , cache_pool = nothing ) where {T,S,N }
278
+ return IndependentlyLinearizedSolution {T,S,N} (
181
279
T[],
182
280
Matrix{S}[],
183
281
BitMatrix (undef, 0 ,0 ),
184
282
ilsc,
283
+ cache_pool,
185
284
)
186
- return ils
187
285
end
188
286
# Automatically create an ILS wrapped around an ILSC from a `prob`
189
- function IndependentlyLinearizedSolution (prob:: SciMLBase.AbstractDEProblem , num_derivatives = 0 )
287
+ function IndependentlyLinearizedSolution (prob:: SciMLBase.AbstractDEProblem , num_derivatives = 0 ;
288
+ cache_pool = nothing ,
289
+ chunk_size:: Int = 512 )
190
290
T = eltype (prob. tspan)
291
+ S = eltype (prob. u0)
191
292
U = isnothing (prob. u0) ? Float64 : eltype (prob. u0)
192
- N = isnothing (prob. u0) ? 0 : length (prob. u0)
193
- chunks = IndependentlyLinearizedSolutionChunks {T,U} (N, num_derivatives)
194
- return IndependentlyLinearizedSolution (chunks)
293
+ num_us = isnothing (prob. u0) ? 0 : length (prob. u0)
294
+ if cache_pool === nothing
295
+ cache_pool = CachePool (
296
+ IndependentlyLinearizedSolutionChunksCache{T,S},
297
+ () -> IndependentlyLinearizedSolutionChunksCache {T,S} (num_us, num_derivatives, chunk_size);
298
+ thread_safe = true ,
299
+ )
300
+ end
301
+ cache = acquire! (cache_pool)
302
+ chunks = IndependentlyLinearizedSolutionChunks {T,U} (num_us, num_derivatives, chunk_size, cache)
303
+ return IndependentlyLinearizedSolution (chunks, cache_pool)
195
304
end
196
305
197
- num_derivatives (ils :: IndependentlyLinearizedSolution ) = ! isempty (ils . us) ? size ( first (ils . us), 1 ) : 0
306
+ num_derivatives (:: IndependentlyLinearizedSolution{T,S,N} ) where {T,S,N} = N
198
307
num_us (ils:: IndependentlyLinearizedSolution ) = length (ils. us)
199
308
Base. size (ils:: IndependentlyLinearizedSolution ) = size (ils. time_mask)
200
309
Base. length (ils:: IndependentlyLinearizedSolution ) = length (ils. ts)
201
310
202
- function finish! (ils:: IndependentlyLinearizedSolution )
311
+ function finish! (ils:: IndependentlyLinearizedSolution{T,S} ) where {T,S}
203
312
function trim_chunk (chunks:: Vector , offset)
204
313
chunks = [chunk for chunk in chunks]
205
314
if eltype (chunks) <: AbstractVector
@@ -216,10 +325,52 @@ function finish!(ils::IndependentlyLinearizedSolution)
216
325
end
217
326
218
327
ilsc = ils. ilsc:: IndependentlyLinearizedSolutionChunks
219
- ts = vcat (trim_chunk (ilsc. t_chunks, ilsc. t_offset)... )
220
- time_mask = hcat (trim_chunk (ilsc. time_masks, ilsc. t_offset)... )
221
- us = [hcat (trim_chunk (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])... )
222
- for u_idx in 1 : length (ilsc. u_chunks)]
328
+
329
+ chunk_len (chunk) = size (chunk, ndims (chunk))
330
+ function chunks_len (chunks:: Vector , offset)
331
+ len = 0
332
+ for chunk_idx in 1 : length (chunks)- 1
333
+ len += chunk_len (chunks[chunk_idx])
334
+ end
335
+ return len + offset - 1
336
+ end
337
+
338
+ function copy_chunk! (out:: Vector , in:: Vector , out_offset:: Int , len= chunk_len (in))
339
+ for idx in 1 : len
340
+ out[idx+ out_offset] = in[idx]
341
+ end
342
+ end
343
+ function copy_chunk! (out:: AbstractMatrix , in:: AbstractMatrix , out_offset:: Int , len= chunk_len (in))
344
+ for zdx in 1 : size (in, 1 )
345
+ for idx in 1 : len
346
+ out[zdx, idx+ out_offset] = in[zdx, idx]
347
+ end
348
+ end
349
+ end
350
+
351
+ function collapse_chunks! (out, chunks, offset:: Int )
352
+ write_offset = 0
353
+ for chunk_idx in 1 : (length (chunks)- 1 )
354
+ chunk = chunks[chunk_idx]
355
+ copy_chunk! (out, chunk, write_offset)
356
+ write_offset += chunk_len (chunk)
357
+ end
358
+ copy_chunk! (out, chunks[end ], write_offset, offset- 1 )
359
+ end
360
+
361
+ # Collapse t_chunks
362
+ ts = Vector {T} (undef, chunks_len (ilsc. t_chunks, ilsc. t_offset))
363
+ collapse_chunks! (ts, ilsc. t_chunks, ilsc. t_offset)
364
+
365
+ # Collapse u_chunks
366
+ us = Vector {Matrix{S}} (undef, length (ilsc. u_chunks))
367
+ for u_idx in 1 : length (ilsc. u_chunks)
368
+ us[u_idx] = Matrix {S} (undef, size (ilsc. u_chunks[u_idx][1 ],1 ), chunks_len (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx]))
369
+ collapse_chunks! (us[u_idx], ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])
370
+ end
371
+
372
+ time_mask = BitMatrix (undef, size (ilsc. time_masks[1 ], 1 ), chunks_len (ilsc. time_masks, ilsc. t_offset))
373
+ collapse_chunks! (time_mask, ilsc. time_masks, ilsc. t_offset)
223
374
224
375
# Sanity-check lengths
225
376
if length (ts) != size (time_mask, 2 )
@@ -238,7 +389,21 @@ function finish!(ils::IndependentlyLinearizedSolution)
238
389
throw (ArgumentError (" Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens) )" ))
239
390
end
240
391
241
- # Update our struct, release the `ilsc`
392
+ # Update our struct, release the `ilsc` and its caches
393
+ for t_chunk in ilsc. t_chunks
394
+ release! (ilsc. cache. t_chunks, t_chunk)
395
+ end
396
+ for u_idx in 1 : length (ilsc. u_chunks)
397
+ for u_chunk in ilsc. u_chunks[u_idx]
398
+ release! (ilsc. cache. u_chunks, u_chunk)
399
+ end
400
+ end
401
+ for time_mask in ilsc. time_masks
402
+ release! (ilsc. cache. time_masks, time_mask)
403
+ end
404
+ if ils. ilsc_cache_pool != = nothing
405
+ release! (ils. ilsc_cache_pool, ilsc. cache)
406
+ end
242
407
ils. ilsc = nothing
243
408
ils. ts = ts
244
409
ils. us = us
0 commit comments