11module PreallocationTools
22
3- using ForwardDiff, ArrayInterface, Adapt
3+ using ArrayInterface, Adapt
44using PrecompileTools
55
66struct FixedSizeDiffCache{T <: AbstractArray , S <: AbstractArray }
@@ -9,11 +9,17 @@ struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
99 any_du:: Vector{Any}
1010end
1111
12+ # Mutable container to hold dual array creator that can be updated by extension
13+ const DUAL_ARRAY_CREATOR = Ref {Union{Nothing,Function}} (nothing )
14+
1215function FixedSizeDiffCache (u:: AbstractArray{T} , siz,
1316 :: Type{Val{chunk_size}} ) where {T, chunk_size}
14- x = ArrayInterface. restructure (u,
15- zeros (ForwardDiff. Dual{Nothing, T, chunk_size},
16- siz... ))
17+ # Try to use ForwardDiff if available, otherwise fallback
18+ x = if ! isnothing (DUAL_ARRAY_CREATOR[])
19+ DUAL_ARRAY_CREATOR[](u, siz, Val{chunk_size})
20+ else
21+ similar (u, siz... )
22+ end
1723 xany = Any[]
1824 FixedSizeDiffCache (deepcopy (u), x, xany)
1925end
@@ -25,8 +31,18 @@ Builds a `FixedSizeDiffCache` object that stores both a version of the cache for
2531and for the `Dual` version of `u`, allowing use of pre-cached vectors with
2632forward-mode automatic differentiation.
2733"""
34+ # Default chunk size calculation without ForwardDiff
35+ default_chunk_size (n) = min (n, 12 )
36+
37+ # Mutable container to hold chunk size function that can be updated by extension
38+ const CHUNK_SIZE_FUNC = Ref {Function} (default_chunk_size)
39+
40+ function forwarddiff_compat_chunk_size (n)
41+ CHUNK_SIZE_FUNC[](n)
42+ end
43+
2844function FixedSizeDiffCache (u:: AbstractArray ,
29- :: Type{Val{N}} = Val{ForwardDiff . pickchunksize (length (u))}) where {
45+ :: Type{Val{N}} = Val{forwarddiff_compat_chunk_size (length (u))}) where {
3046 N,
3147}
3248 FixedSizeDiffCache (u, size (u), Val{N})
@@ -36,34 +52,10 @@ function FixedSizeDiffCache(u::AbstractArray, N::Integer)
3652 FixedSizeDiffCache (u, size (u), Val{N})
3753end
3854
39- chunksize (:: Type{ForwardDiff.Dual{T, V, N}} ) where {T, V, N} = N
55+ # Generic fallback for chunksize
56+ chunksize (:: Type{T} ) where {T} = 0
4057
41- function get_tmp (dc:: FixedSizeDiffCache , u:: T ) where {T <: ForwardDiff.Dual }
42- x = reinterpret (T, dc. dual_du)
43- if chunksize (T) === chunksize (eltype (dc. dual_du))
44- x
45- else
46- @view x[axes (dc. du)... ]
47- end
48- end
49-
50- function get_tmp (dc:: FixedSizeDiffCache , u:: Type{T} ) where {T <: ForwardDiff.Dual }
51- x = reinterpret (T, dc. dual_du)
52- if chunksize (T) === chunksize (eltype (dc. dual_du))
53- x
54- else
55- @view x[axes (dc. du)... ]
56- end
57- end
58-
59- function get_tmp (dc:: FixedSizeDiffCache , u:: AbstractArray{T} ) where {T <: ForwardDiff.Dual }
60- x = reinterpret (T, dc. dual_du)
61- if chunksize (T) === chunksize (eltype (dc. dual_du))
62- x
63- else
64- @view x[axes (dc. du)... ]
65- end
66- end
58+ # ForwardDiff-specific methods moved to extension
6759
6860function get_tmp (dc:: FixedSizeDiffCache , u:: Union{Number, AbstractArray} )
6961 if promote_type (eltype (dc. du), eltype (u)) <: eltype (dc. du)
@@ -103,19 +95,19 @@ function DiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T}
10395end
10496
10597"""
106- `DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize (length(u)); levels::Int = 1)`
98+ `DiffCache(u::AbstractArray, N::Int = forwarddiff_compat_chunk_size (length(u)); levels::Int = 1)`
10799`DiffCache(u::AbstractArray; N::AbstractArray{<:Int})`
108100
109101Builds a `DiffCache` object that stores both a version of the cache for `u`
110102and for the `Dual` version of `u`, allowing use of pre-cached vectors with
111103forward-mode automatic differentiation via
112- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
104+ [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) (when available) .
113105Supports nested AD via keyword `levels` or specifying an array of chunk sizes.
114106
115107The `DiffCache` also supports sparsity detection via
116108[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/).
117109"""
118- function DiffCache (u:: AbstractArray , N:: Int = ForwardDiff . pickchunksize (length (u));
110+ function DiffCache (u:: AbstractArray , N:: Int = forwarddiff_compat_chunk_size (length (u));
119111 levels:: Int = 1 )
120112 DiffCache (u, size (u), N * ones (Int, levels))
121113end
@@ -133,41 +125,7 @@ const dualcache = DiffCache
133125
134126Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`.
135127"""
136- function get_tmp (dc:: DiffCache , u:: T ) where {T <: ForwardDiff.Dual }
137- if isbitstype (T)
138- nelem = div (sizeof (T), sizeof (eltype (dc. dual_du))) * length (dc. du)
139- if nelem > length (dc. dual_du)
140- enlargediffcache! (dc, nelem)
141- end
142- _restructure (dc. du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
143- else
144- _restructure (dc. du, zeros (T, size (dc. du)))
145- end
146- end
147-
148- function get_tmp (dc:: DiffCache , :: Type{T} ) where {T <: ForwardDiff.Dual }
149- if isbitstype (T)
150- nelem = div (sizeof (T), sizeof (eltype (dc. dual_du))) * length (dc. du)
151- if nelem > length (dc. dual_du)
152- enlargediffcache! (dc, nelem)
153- end
154- _restructure (dc. du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
155- else
156- _restructure (dc. du, zeros (T, size (dc. du)))
157- end
158- end
159-
160- function get_tmp (dc:: DiffCache , u:: AbstractArray{T} ) where {T <: ForwardDiff.Dual }
161- if isbitstype (T)
162- nelem = div (sizeof (T), sizeof (eltype (dc. dual_du))) * length (dc. du)
163- if nelem > length (dc. dual_du)
164- enlargediffcache! (dc, nelem)
165- end
166- _restructure (dc. du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
167- else
168- _restructure (dc. du, zeros (T, size (dc. du)))
169- end
170- end
128+ # ForwardDiff-specific methods moved to extension
171129
172130function get_tmp (dc:: DiffCache , u:: Union{Number, AbstractArray} )
173131 if promote_type (eltype (dc. du), eltype (u)) <: eltype (dc. du)
@@ -291,6 +249,9 @@ Base.getindex(b::GeneralLazyBufferCache, u::T) where {T} = get_tmp(b, u)
291249export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
292250export get_tmp
293251
252+ # Export internal functions for extension use (but not public API)
253+ # These are needed by the ForwardDiff extension
254+
294255@setup_workload begin
295256 @compile_workload begin
296257 # Basic precompilation
0 commit comments