|
1 | 1 | module PreallocationToolsSparseConnectivityTracerExt |
2 | 2 |
|
3 | | -using PreallocationTools |
4 | | -isdefined(Base, :get_extension) ? (import SparseConnectivityTracer) : |
5 | | -(import ..SparseConnectivityTracer) |
| 3 | +using PreallocationTools: PreallocationTools, DiffCache, get_tmp |
| 4 | +using SparseConnectivityTracer: AbstractTracer, Dual |
6 | 5 |
|
7 | | -function PreallocationTools.get_tmp( |
8 | | - dc::DiffCache, u::T) where {T <: SparseConnectivityTracer.Dual} |
9 | | - if isbitstype(T) |
10 | | - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) |
11 | | - if nelem > length(dc.dual_du) |
12 | | - PreallocationTools.enlargediffcache!(dc, nelem) |
13 | | - end |
14 | | - PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) |
15 | | - else |
16 | | - PreallocationTools._restructure(dc.du, zeros(T, size(dc.du))) |
17 | | - end |
| 6 | +function PreallocationTools.get_tmp(dc::DiffCache, u::T) where {T <: Union{AbstractTracer, Dual}} |
| 7 | + return get_tmp(dc, typeof(u)) |
18 | 8 | end |
19 | 9 |
|
20 | | -function PreallocationTools.get_tmp( |
21 | | - dc::DiffCache, ::Type{T}) where {T <: SparseConnectivityTracer.Dual} |
22 | | - if isbitstype(T) |
23 | | - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) |
24 | | - if nelem > length(dc.dual_du) |
25 | | - PreallocationTools.enlargediffcache!(dc, nelem) |
26 | | - end |
27 | | - PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) |
28 | | - else |
29 | | - PreallocationTools._restructure(dc.du, zeros(T, size(dc.du))) |
30 | | - end |
| 10 | +function PreallocationTools.get_tmp(dc::DiffCache, u::AbstractArray{<:T}) where {T <: Union{AbstractTracer, Dual}} |
| 11 | + return get_tmp(dc, eltype(u)) |
31 | 12 | end |
32 | 13 |
|
33 | | -function PreallocationTools.get_tmp( |
34 | | - dc::DiffCache, u::AbstractArray{T}) where {T <: SparseConnectivityTracer.Dual} |
35 | | - if isbitstype(T) |
36 | | - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) |
37 | | - if nelem > length(dc.dual_du) |
38 | | - PreallocationTools.enlargediffcache!(dc, nelem) |
39 | | - end |
40 | | - PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) |
41 | | - else |
42 | | - PreallocationTools._restructure(dc.du, zeros(T, size(dc.du))) |
43 | | - end |
| 14 | +function PreallocationTools.get_tmp(dc::DiffCache, ::Type{T}) where {T <: Union{AbstractTracer, Dual}} |
| 15 | + # We allocate memory here since we assume that sparsity connection happens only |
| 16 | + # once (or maybe a few times). This simplifies the implementation and allows us |
| 17 | + # to save memory in the long run since we do not need to store an additional |
| 18 | + # cache for the sparsity detection that would be used only once but carried |
| 19 | + # around forever. |
| 20 | + return similar(dc.du, T) |
44 | 21 | end |
45 | 22 |
|
46 | 23 | end |
0 commit comments