Skip to content

Commit d74ddb2

Browse files
Make ForwardDiff a weak dependency extension
- Move ForwardDiff-specific code to PreallocationToolsForwardDiffExt - ForwardDiff is now an optional dependency loaded via package extensions - Default implementations provided when ForwardDiff is not available - Uses ref cells to allow extension to modify behavior at runtime - All ForwardDiff-specific methods for Dual types moved to extension
1 parent bbb20ef commit d74ddb2

File tree

3 files changed

+124
-72
lines changed

3 files changed

+124
-72
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ version = "0.4.29"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
109
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1110

1211
[weakdeps]
12+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1414
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1515

1616
[extensions]
17+
PreallocationToolsForwardDiffExt = "ForwardDiff"
1718
PreallocationToolsReverseDiffExt = "ReverseDiff"
1819
PreallocationToolsSparseConnectivityTracerExt = "SparseConnectivityTracer"
1920

@@ -43,6 +44,7 @@ julia = "1.10"
4344
[extras]
4445
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4546
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
47+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4648
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
4749
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4850
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -59,4 +61,4 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
5961
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6062

6163
[targets]
62-
test = ["Aqua", "ADTypes", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
64+
test = ["Aqua", "ADTypes", "ForwardDiff", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
module PreallocationToolsForwardDiffExt
2+
3+
using PreallocationTools
4+
using ForwardDiff
5+
using ArrayInterface
6+
using Adapt
7+
8+
# Initialize on module load
9+
function __init__()
10+
# Set the dual array creator function
11+
PreallocationTools.DUAL_ARRAY_CREATOR[] = function(u::AbstractArray{T}, siz,
12+
::Type{Val{chunk_size}}) where {T, chunk_size}
13+
ArrayInterface.restructure(u,
14+
zeros(ForwardDiff.Dual{Nothing, T, chunk_size},
15+
siz...))
16+
end
17+
18+
# Set the chunk size function to use ForwardDiff's pickchunksize
19+
PreallocationTools.CHUNK_SIZE_FUNC[] = ForwardDiff.pickchunksize
20+
end
21+
22+
# Define chunksize for ForwardDiff.Dual types
23+
PreallocationTools.chunksize(::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = N
24+
25+
# Define get_tmp methods for ForwardDiff.Dual types
26+
function PreallocationTools.get_tmp(dc::PreallocationTools.FixedSizeDiffCache, u::T) where {T <: ForwardDiff.Dual}
27+
x = reinterpret(T, dc.dual_du)
28+
if PreallocationTools.chunksize(T) === PreallocationTools.chunksize(eltype(dc.dual_du))
29+
x
30+
else
31+
@view x[axes(dc.du)...]
32+
end
33+
end
34+
35+
function PreallocationTools.get_tmp(dc::PreallocationTools.FixedSizeDiffCache, u::Type{T}) where {T <: ForwardDiff.Dual}
36+
x = reinterpret(T, dc.dual_du)
37+
if PreallocationTools.chunksize(T) === PreallocationTools.chunksize(eltype(dc.dual_du))
38+
x
39+
else
40+
@view x[axes(dc.du)...]
41+
end
42+
end
43+
44+
function PreallocationTools.get_tmp(dc::PreallocationTools.FixedSizeDiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
45+
x = reinterpret(T, dc.dual_du)
46+
if PreallocationTools.chunksize(T) === PreallocationTools.chunksize(eltype(dc.dual_du))
47+
x
48+
else
49+
@view x[axes(dc.du)...]
50+
end
51+
end
52+
53+
function PreallocationTools.get_tmp(dc::PreallocationTools.DiffCache, u::T) where {T <: ForwardDiff.Dual}
54+
if isbitstype(T)
55+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
56+
if nelem > length(dc.dual_du)
57+
PreallocationTools.enlargediffcache!(dc, nelem)
58+
end
59+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
60+
else
61+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
62+
end
63+
end
64+
65+
function PreallocationTools.get_tmp(dc::PreallocationTools.DiffCache, ::Type{T}) where {T <: ForwardDiff.Dual}
66+
if isbitstype(T)
67+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
68+
if nelem > length(dc.dual_du)
69+
PreallocationTools.enlargediffcache!(dc, nelem)
70+
end
71+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
72+
else
73+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
74+
end
75+
end
76+
77+
function PreallocationTools.get_tmp(dc::PreallocationTools.DiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
78+
if isbitstype(T)
79+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
80+
if nelem > length(dc.dual_du)
81+
PreallocationTools.enlargediffcache!(dc, nelem)
82+
end
83+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
84+
else
85+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
86+
end
87+
end
88+
89+
end

src/PreallocationTools.jl

Lines changed: 31 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module PreallocationTools
22

3-
using ForwardDiff, ArrayInterface, Adapt
3+
using ArrayInterface, Adapt
44
using PrecompileTools
55

66
struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
@@ -9,11 +9,17 @@ struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
99
any_du::Vector{Any}
1010
end
1111

12+
# Mutable container to hold dual array creator that can be updated by extension
13+
const DUAL_ARRAY_CREATOR = Ref{Union{Nothing,Function}}(nothing)
14+
1215
function FixedSizeDiffCache(u::AbstractArray{T}, siz,
1316
::Type{Val{chunk_size}}) where {T, chunk_size}
14-
x = ArrayInterface.restructure(u,
15-
zeros(ForwardDiff.Dual{Nothing, T, chunk_size},
16-
siz...))
17+
# Try to use ForwardDiff if available, otherwise fallback
18+
x = if !isnothing(DUAL_ARRAY_CREATOR[])
19+
DUAL_ARRAY_CREATOR[](u, siz, Val{chunk_size})
20+
else
21+
similar(u, siz...)
22+
end
1723
xany = Any[]
1824
FixedSizeDiffCache(deepcopy(u), x, xany)
1925
end
@@ -25,8 +31,18 @@ Builds a `FixedSizeDiffCache` object that stores both a version of the cache for
2531
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
2632
forward-mode automatic differentiation.
2733
"""
34+
# Default chunk size calculation without ForwardDiff
35+
default_chunk_size(n) = min(n, 12)
36+
37+
# Mutable container to hold chunk size function that can be updated by extension
38+
const CHUNK_SIZE_FUNC = Ref{Function}(default_chunk_size)
39+
40+
function forwarddiff_compat_chunk_size(n)
41+
CHUNK_SIZE_FUNC[](n)
42+
end
43+
2844
function FixedSizeDiffCache(u::AbstractArray,
29-
::Type{Val{N}} = Val{ForwardDiff.pickchunksize(length(u))}) where {
45+
::Type{Val{N}} = Val{forwarddiff_compat_chunk_size(length(u))}) where {
3046
N,
3147
}
3248
FixedSizeDiffCache(u, size(u), Val{N})
@@ -36,34 +52,10 @@ function FixedSizeDiffCache(u::AbstractArray, N::Integer)
3652
FixedSizeDiffCache(u, size(u), Val{N})
3753
end
3854

39-
chunksize(::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = N
55+
# Generic fallback for chunksize
56+
chunksize(::Type{T}) where {T} = 0
4057

41-
function get_tmp(dc::FixedSizeDiffCache, u::T) where {T <: ForwardDiff.Dual}
42-
x = reinterpret(T, dc.dual_du)
43-
if chunksize(T) === chunksize(eltype(dc.dual_du))
44-
x
45-
else
46-
@view x[axes(dc.du)...]
47-
end
48-
end
49-
50-
function get_tmp(dc::FixedSizeDiffCache, u::Type{T}) where {T <: ForwardDiff.Dual}
51-
x = reinterpret(T, dc.dual_du)
52-
if chunksize(T) === chunksize(eltype(dc.dual_du))
53-
x
54-
else
55-
@view x[axes(dc.du)...]
56-
end
57-
end
58-
59-
function get_tmp(dc::FixedSizeDiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
60-
x = reinterpret(T, dc.dual_du)
61-
if chunksize(T) === chunksize(eltype(dc.dual_du))
62-
x
63-
else
64-
@view x[axes(dc.du)...]
65-
end
66-
end
58+
# ForwardDiff-specific methods moved to extension
6759

6860
function get_tmp(dc::FixedSizeDiffCache, u::Union{Number, AbstractArray})
6961
if promote_type(eltype(dc.du), eltype(u)) <: eltype(dc.du)
@@ -103,19 +95,19 @@ function DiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T}
10395
end
10496

10597
"""
106-
`DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); levels::Int = 1)`
98+
`DiffCache(u::AbstractArray, N::Int = forwarddiff_compat_chunk_size(length(u)); levels::Int = 1)`
10799
`DiffCache(u::AbstractArray; N::AbstractArray{<:Int})`
108100
109101
Builds a `DiffCache` object that stores both a version of the cache for `u`
110102
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
111103
forward-mode automatic differentiation via
112-
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
104+
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) (when available).
113105
Supports nested AD via keyword `levels` or specifying an array of chunk sizes.
114106
115107
The `DiffCache` also supports sparsity detection via
116108
[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/).
117109
"""
118-
function DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u));
110+
function DiffCache(u::AbstractArray, N::Int = forwarddiff_compat_chunk_size(length(u));
119111
levels::Int = 1)
120112
DiffCache(u, size(u), N * ones(Int, levels))
121113
end
@@ -133,41 +125,7 @@ const dualcache = DiffCache
133125
134126
Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`.
135127
"""
136-
function get_tmp(dc::DiffCache, u::T) where {T <: ForwardDiff.Dual}
137-
if isbitstype(T)
138-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
139-
if nelem > length(dc.dual_du)
140-
enlargediffcache!(dc, nelem)
141-
end
142-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
143-
else
144-
_restructure(dc.du, zeros(T, size(dc.du)))
145-
end
146-
end
147-
148-
function get_tmp(dc::DiffCache, ::Type{T}) where {T <: ForwardDiff.Dual}
149-
if isbitstype(T)
150-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
151-
if nelem > length(dc.dual_du)
152-
enlargediffcache!(dc, nelem)
153-
end
154-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
155-
else
156-
_restructure(dc.du, zeros(T, size(dc.du)))
157-
end
158-
end
159-
160-
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
161-
if isbitstype(T)
162-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
163-
if nelem > length(dc.dual_du)
164-
enlargediffcache!(dc, nelem)
165-
end
166-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
167-
else
168-
_restructure(dc.du, zeros(T, size(dc.du)))
169-
end
170-
end
128+
# ForwardDiff-specific methods moved to extension
171129

172130
function get_tmp(dc::DiffCache, u::Union{Number, AbstractArray})
173131
if promote_type(eltype(dc.du), eltype(u)) <: eltype(dc.du)
@@ -291,6 +249,9 @@ Base.getindex(b::GeneralLazyBufferCache, u::T) where {T} = get_tmp(b, u)
291249
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
292250
export get_tmp
293251

252+
# Export internal functions for extension use (but not public API)
253+
# These are needed by the ForwardDiff extension
254+
294255
@setup_workload begin
295256
@compile_workload begin
296257
# Basic precompilation

0 commit comments

Comments
 (0)