Skip to content

Commit e67c91d

Browse files
Merge pull request #131 from ChrisRackauckas-Claude/forwarddiff-weakdep
Make ForwardDiff a weak dependency extension
2 parents bbb20ef + bca8334 commit e67c91d

File tree

3 files changed

+106
-72
lines changed

3 files changed

+106
-72
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ version = "0.4.29"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
109
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1110

1211
[weakdeps]
12+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1414
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1515

1616
[extensions]
17+
PreallocationToolsForwardDiffExt = "ForwardDiff"
1718
PreallocationToolsReverseDiffExt = "ReverseDiff"
1819
PreallocationToolsSparseConnectivityTracerExt = "SparseConnectivityTracer"
1920

@@ -43,6 +44,7 @@ julia = "1.10"
4344
[extras]
4445
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4546
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
47+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4648
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
4749
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4850
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -59,4 +61,4 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
5961
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6062

6163
[targets]
62-
test = ["Aqua", "ADTypes", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
64+
test = ["Aqua", "ADTypes", "ForwardDiff", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
module PreallocationToolsForwardDiffExt
2+
3+
using PreallocationTools
4+
using ForwardDiff
5+
using ArrayInterface
6+
using Adapt
7+
8+
function PreallocationTools.dualarraycreator(u::AbstractArray{T}, siz,
9+
::Type{Val{chunk_size}}) where {T, chunk_size}
10+
ArrayInterface.restructure(u,
11+
zeros(ForwardDiff.Dual{Nothing, T, chunk_size},
12+
siz...))
13+
end
14+
15+
PreallocationTools.forwarddiff_compat_chunk_size(x::Int) = ForwardDiff.pickchunksize(x)
16+
17+
# Define chunksize for ForwardDiff.Dual types
18+
PreallocationTools.chunksize(::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = N
19+
20+
# Define get_tmp methods for ForwardDiff.Dual types
21+
function PreallocationTools.get_tmp(dc::PreallocationTools.FixedSizeDiffCache, u::T) where {T <: ForwardDiff.Dual}
22+
x = reinterpret(T, dc.dual_du)
23+
if PreallocationTools.chunksize(T) === PreallocationTools.chunksize(eltype(dc.dual_du))
24+
x
25+
else
26+
@view x[axes(dc.du)...]
27+
end
28+
end
29+
30+
function PreallocationTools.get_tmp(dc::PreallocationTools.FixedSizeDiffCache, u::Type{T}) where {T <: ForwardDiff.Dual}
31+
x = reinterpret(T, dc.dual_du)
32+
if PreallocationTools.chunksize(T) === PreallocationTools.chunksize(eltype(dc.dual_du))
33+
x
34+
else
35+
@view x[axes(dc.du)...]
36+
end
37+
end
38+
39+
function PreallocationTools.get_tmp(dc::PreallocationTools.FixedSizeDiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
40+
x = reinterpret(T, dc.dual_du)
41+
if PreallocationTools.chunksize(T) === PreallocationTools.chunksize(eltype(dc.dual_du))
42+
x
43+
else
44+
@view x[axes(dc.du)...]
45+
end
46+
end
47+
48+
function PreallocationTools.get_tmp(dc::PreallocationTools.DiffCache, u::T) where {T <: ForwardDiff.Dual}
49+
if isbitstype(T)
50+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
51+
if nelem > length(dc.dual_du)
52+
PreallocationTools.enlargediffcache!(dc, nelem)
53+
end
54+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
55+
else
56+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
57+
end
58+
end
59+
60+
function PreallocationTools.get_tmp(dc::PreallocationTools.DiffCache, ::Type{T}) where {T <: ForwardDiff.Dual}
61+
if isbitstype(T)
62+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
63+
if nelem > length(dc.dual_du)
64+
PreallocationTools.enlargediffcache!(dc, nelem)
65+
end
66+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
67+
else
68+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
69+
end
70+
end
71+
72+
function PreallocationTools.get_tmp(dc::PreallocationTools.DiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
73+
if isbitstype(T)
74+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
75+
if nelem > length(dc.dual_du)
76+
PreallocationTools.enlargediffcache!(dc, nelem)
77+
end
78+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
79+
else
80+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
81+
end
82+
end
83+
84+
end

src/PreallocationTools.jl

Lines changed: 18 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module PreallocationTools
22

3-
using ForwardDiff, ArrayInterface, Adapt
3+
using ArrayInterface, Adapt
44
using PrecompileTools
55

66
struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
@@ -9,15 +9,18 @@ struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
99
any_du::Vector{Any}
1010
end
1111

12+
# Mutable container to hold dual array creator that can be updated by extension
13+
dualarraycreator(args...) = nothing
14+
1215
function FixedSizeDiffCache(u::AbstractArray{T}, siz,
1316
::Type{Val{chunk_size}}) where {T, chunk_size}
14-
x = ArrayInterface.restructure(u,
15-
zeros(ForwardDiff.Dual{Nothing, T, chunk_size},
16-
siz...))
17+
x = dualarraycreator(u, siz, Val{chunk_size})
1718
xany = Any[]
1819
FixedSizeDiffCache(deepcopy(u), x, xany)
1920
end
2021

22+
forwarddiff_compat_chunk_size(n) = 0
23+
2124
"""
2225
`FixedSizeDiffCache(u::AbstractArray, N = Val{default_cache_size(length(u))})`
2326
@@ -26,7 +29,7 @@ and for the `Dual` version of `u`, allowing use of pre-cached vectors with
2629
forward-mode automatic differentiation.
2730
"""
2831
function FixedSizeDiffCache(u::AbstractArray,
29-
::Type{Val{N}} = Val{ForwardDiff.pickchunksize(length(u))}) where {
32+
::Type{Val{N}} = Val{forwarddiff_compat_chunk_size(length(u))}) where {
3033
N,
3134
}
3235
FixedSizeDiffCache(u, size(u), Val{N})
@@ -36,34 +39,10 @@ function FixedSizeDiffCache(u::AbstractArray, N::Integer)
3639
FixedSizeDiffCache(u, size(u), Val{N})
3740
end
3841

39-
chunksize(::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = N
40-
41-
function get_tmp(dc::FixedSizeDiffCache, u::T) where {T <: ForwardDiff.Dual}
42-
x = reinterpret(T, dc.dual_du)
43-
if chunksize(T) === chunksize(eltype(dc.dual_du))
44-
x
45-
else
46-
@view x[axes(dc.du)...]
47-
end
48-
end
49-
50-
function get_tmp(dc::FixedSizeDiffCache, u::Type{T}) where {T <: ForwardDiff.Dual}
51-
x = reinterpret(T, dc.dual_du)
52-
if chunksize(T) === chunksize(eltype(dc.dual_du))
53-
x
54-
else
55-
@view x[axes(dc.du)...]
56-
end
57-
end
42+
# Generic fallback for chunksize
43+
chunksize(::Type{T}) where {T} = 0
5844

59-
function get_tmp(dc::FixedSizeDiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
60-
x = reinterpret(T, dc.dual_du)
61-
if chunksize(T) === chunksize(eltype(dc.dual_du))
62-
x
63-
else
64-
@view x[axes(dc.du)...]
65-
end
66-
end
45+
# ForwardDiff-specific methods moved to extension
6746

6847
function get_tmp(dc::FixedSizeDiffCache, u::Union{Number, AbstractArray})
6948
if promote_type(eltype(dc.du), eltype(u)) <: eltype(dc.du)
@@ -103,19 +82,19 @@ function DiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T}
10382
end
10483

10584
"""
106-
`DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); levels::Int = 1)`
85+
`DiffCache(u::AbstractArray, N::Int = forwarddiff_compat_chunk_size(length(u)); levels::Int = 1)`
10786
`DiffCache(u::AbstractArray; N::AbstractArray{<:Int})`
10887
10988
Builds a `DiffCache` object that stores both a version of the cache for `u`
11089
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
11190
forward-mode automatic differentiation via
112-
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
91+
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) (when available).
11392
Supports nested AD via keyword `levels` or specifying an array of chunk sizes.
11493
11594
The `DiffCache` also supports sparsity detection via
11695
[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/).
11796
"""
118-
function DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u));
97+
function DiffCache(u::AbstractArray, N::Int = forwarddiff_compat_chunk_size(length(u));
11998
levels::Int = 1)
12099
DiffCache(u, size(u), N * ones(Int, levels))
121100
end
@@ -133,41 +112,7 @@ const dualcache = DiffCache
133112
134113
Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`.
135114
"""
136-
function get_tmp(dc::DiffCache, u::T) where {T <: ForwardDiff.Dual}
137-
if isbitstype(T)
138-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
139-
if nelem > length(dc.dual_du)
140-
enlargediffcache!(dc, nelem)
141-
end
142-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
143-
else
144-
_restructure(dc.du, zeros(T, size(dc.du)))
145-
end
146-
end
147-
148-
function get_tmp(dc::DiffCache, ::Type{T}) where {T <: ForwardDiff.Dual}
149-
if isbitstype(T)
150-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
151-
if nelem > length(dc.dual_du)
152-
enlargediffcache!(dc, nelem)
153-
end
154-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
155-
else
156-
_restructure(dc.du, zeros(T, size(dc.du)))
157-
end
158-
end
159-
160-
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
161-
if isbitstype(T)
162-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
163-
if nelem > length(dc.dual_du)
164-
enlargediffcache!(dc, nelem)
165-
end
166-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
167-
else
168-
_restructure(dc.du, zeros(T, size(dc.du)))
169-
end
170-
end
115+
# ForwardDiff-specific methods moved to extension
171116

172117
function get_tmp(dc::DiffCache, u::Union{Number, AbstractArray})
173118
if promote_type(eltype(dc.du), eltype(u)) <: eltype(dc.du)
@@ -291,6 +236,9 @@ Base.getindex(b::GeneralLazyBufferCache, u::T) where {T} = get_tmp(b, u)
291236
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
292237
export get_tmp
293238

239+
# Export internal functions for extension use (but not public API)
240+
# These are needed by the ForwardDiff extension
241+
294242
@setup_workload begin
295243
@compile_workload begin
296244
# Basic precompilation

0 commit comments

Comments
 (0)