1
1
module PreallocationTools
2
2
3
- using ForwardDiff, ArrayInterface, Adapt
3
+ using ArrayInterface, Adapt
4
4
using PrecompileTools
5
5
6
6
struct FixedSizeDiffCache{T <: AbstractArray , S <: AbstractArray }
@@ -9,15 +9,18 @@ struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
9
9
any_du:: Vector{Any}
10
10
end
11
11
12
+ # Mutable container to hold dual array creator that can be updated by extension
13
+ dualarraycreator (args... ) = nothing
14
+
12
15
function FixedSizeDiffCache (u:: AbstractArray{T} , siz,
13
16
:: Type{Val{chunk_size}} ) where {T, chunk_size}
14
- x = ArrayInterface. restructure (u,
15
- zeros (ForwardDiff. Dual{Nothing, T, chunk_size},
16
- siz... ))
17
+ x = dualarraycreator (u, siz, Val{chunk_size})
17
18
xany = Any[]
18
19
FixedSizeDiffCache (deepcopy (u), x, xany)
19
20
end
20
21
22
+ forwarddiff_compat_chunk_size (n) = 0
23
+
21
24
"""
22
25
`FixedSizeDiffCache(u::AbstractArray, N = Val{default_cache_size(length(u))})`
23
26
@@ -26,7 +29,7 @@ and for the `Dual` version of `u`, allowing use of pre-cached vectors with
26
29
forward-mode automatic differentiation.
27
30
"""
28
31
function FixedSizeDiffCache (u:: AbstractArray ,
29
- :: Type{Val{N}} = Val{ForwardDiff . pickchunksize (length (u))}) where {
32
+ :: Type{Val{N}} = Val{forwarddiff_compat_chunk_size (length (u))}) where {
30
33
N,
31
34
}
32
35
FixedSizeDiffCache (u, size (u), Val{N})
@@ -36,34 +39,10 @@ function FixedSizeDiffCache(u::AbstractArray, N::Integer)
36
39
FixedSizeDiffCache (u, size (u), Val{N})
37
40
end
38
41
39
- chunksize (:: Type{ForwardDiff.Dual{T, V, N}} ) where {T, V, N} = N
40
-
41
- function get_tmp (dc:: FixedSizeDiffCache , u:: T ) where {T <: ForwardDiff.Dual }
42
- x = reinterpret (T, dc. dual_du)
43
- if chunksize (T) === chunksize (eltype (dc. dual_du))
44
- x
45
- else
46
- @view x[axes (dc. du)... ]
47
- end
48
- end
49
-
50
- function get_tmp (dc:: FixedSizeDiffCache , u:: Type{T} ) where {T <: ForwardDiff.Dual }
51
- x = reinterpret (T, dc. dual_du)
52
- if chunksize (T) === chunksize (eltype (dc. dual_du))
53
- x
54
- else
55
- @view x[axes (dc. du)... ]
56
- end
57
- end
42
+ # Generic fallback for chunksize
43
+ chunksize (:: Type{T} ) where {T} = 0
58
44
59
- function get_tmp (dc:: FixedSizeDiffCache , u:: AbstractArray{T} ) where {T <: ForwardDiff.Dual }
60
- x = reinterpret (T, dc. dual_du)
61
- if chunksize (T) === chunksize (eltype (dc. dual_du))
62
- x
63
- else
64
- @view x[axes (dc. du)... ]
65
- end
66
- end
45
+ # ForwardDiff-specific methods moved to extension
67
46
68
47
function get_tmp (dc:: FixedSizeDiffCache , u:: Union{Number, AbstractArray} )
69
48
if promote_type (eltype (dc. du), eltype (u)) <: eltype (dc. du)
@@ -103,19 +82,19 @@ function DiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T}
103
82
end
104
83
105
84
"""
106
- `DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize (length(u)); levels::Int = 1)`
85
+ `DiffCache(u::AbstractArray, N::Int = forwarddiff_compat_chunk_size (length(u)); levels::Int = 1)`
107
86
`DiffCache(u::AbstractArray; N::AbstractArray{<:Int})`
108
87
109
88
Builds a `DiffCache` object that stores both a version of the cache for `u`
110
89
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
111
90
forward-mode automatic differentiation via
112
- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
91
+ [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) (when available) .
113
92
Supports nested AD via keyword `levels` or specifying an array of chunk sizes.
114
93
115
94
The `DiffCache` also supports sparsity detection via
116
95
[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/).
117
96
"""
118
- function DiffCache (u:: AbstractArray , N:: Int = ForwardDiff . pickchunksize (length (u));
97
+ function DiffCache (u:: AbstractArray , N:: Int = forwarddiff_compat_chunk_size (length (u));
119
98
levels:: Int = 1 )
120
99
DiffCache (u, size (u), N * ones (Int, levels))
121
100
end
@@ -133,41 +112,7 @@ const dualcache = DiffCache
133
112
134
113
Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`.
135
114
"""
136
- function get_tmp (dc:: DiffCache , u:: T ) where {T <: ForwardDiff.Dual }
137
- if isbitstype (T)
138
- nelem = div (sizeof (T), sizeof (eltype (dc. dual_du))) * length (dc. du)
139
- if nelem > length (dc. dual_du)
140
- enlargediffcache! (dc, nelem)
141
- end
142
- _restructure (dc. du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
143
- else
144
- _restructure (dc. du, zeros (T, size (dc. du)))
145
- end
146
- end
147
-
148
- function get_tmp (dc:: DiffCache , :: Type{T} ) where {T <: ForwardDiff.Dual }
149
- if isbitstype (T)
150
- nelem = div (sizeof (T), sizeof (eltype (dc. dual_du))) * length (dc. du)
151
- if nelem > length (dc. dual_du)
152
- enlargediffcache! (dc, nelem)
153
- end
154
- _restructure (dc. du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
155
- else
156
- _restructure (dc. du, zeros (T, size (dc. du)))
157
- end
158
- end
159
-
160
- function get_tmp (dc:: DiffCache , u:: AbstractArray{T} ) where {T <: ForwardDiff.Dual }
161
- if isbitstype (T)
162
- nelem = div (sizeof (T), sizeof (eltype (dc. dual_du))) * length (dc. du)
163
- if nelem > length (dc. dual_du)
164
- enlargediffcache! (dc, nelem)
165
- end
166
- _restructure (dc. du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
167
- else
168
- _restructure (dc. du, zeros (T, size (dc. du)))
169
- end
170
- end
115
+ # ForwardDiff-specific methods moved to extension
171
116
172
117
function get_tmp (dc:: DiffCache , u:: Union{Number, AbstractArray} )
173
118
if promote_type (eltype (dc. du), eltype (u)) <: eltype (dc. du)
@@ -291,6 +236,9 @@ Base.getindex(b::GeneralLazyBufferCache, u::T) where {T} = get_tmp(b, u)
291
236
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
292
237
export get_tmp
293
238
239
+ # Export internal functions for extension use (but not public API)
240
+ # These are needed by the ForwardDiff extension
241
+
294
242
@setup_workload begin
295
243
@compile_workload begin
296
244
# Basic precompilation
0 commit comments