Skip to content

Commit 22f4dd7

Browse files
authored
Use DifferentiationInterface for autodiff, allow ADTypes (#153)
* Start DI integration * Fix bug * Handle constraints * Bump version to 7.9.0 * Bump Julia compat to 1.10 * Min dif * Improve coverage * Add docs * Get rid of DiffResults * Bump version
1 parent e0d5949 commit 22f4dd7

File tree

8 files changed

+138
-333
lines changed

8 files changed

+138
-333
lines changed

Project.toml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
name = "NLSolversBase"
22
uuid = "d41bc354-129a-5804-8e4c-c37616107c6c"
3-
version = "7.8.3"
3+
version = "7.9.0"
44

55
[deps]
6-
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
6+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
7+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
78
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
89
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
910
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1011

1112
[compat]
12-
DiffResults = "1.0"
13-
ForwardDiff = "0.10"
13+
ADTypes = "1.11.0"
14+
DifferentiationInterface = "0.6.43"
1415
FiniteDiff = "2.0"
15-
julia = "1.5"
16+
ForwardDiff = "0.10"
17+
julia = "1.10"
1618

1719
[extras]
20+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1821
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1922
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2023
OptimTestProblems = "cec144fc-5a64-5bc6-99fb-dde8f63e154c"
@@ -24,4 +27,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2427
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2528

2629
[targets]
27-
test = ["ComponentArrays", "LinearAlgebra", "OptimTestProblems", "Random", "RecursiveArrayTools", "SparseArrays", "Test"]
30+
test = ["ADTypes", "ComponentArrays", "LinearAlgebra", "OptimTestProblems", "Random", "RecursiveArrayTools", "SparseArrays", "Test"]

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ There are currently three main types: `NonDifferentiable`, `OnceDifferentiable`,
2424

2525
The words in front of `Differentiable` in the type names (`Non`, `Once`, `Twice`) are not meant to indicate a specific classification of the function as such (a `OnceDifferentiable` might be constructed for an infinitely differentiable function), but rather signal to an algorithm whether the correct functions have been constructed or whether automatic differentiation should be used to further differentiate the function.
2626

27+
## Automatic differentiation
28+
29+
Some constructors for `OnceDifferentiable`, `TwiceDifferentiable`, `OnceDifferentiableConstraints` and `TwiceDifferentiableConstraints` accept a positional argument called `autodiff`.
30+
This argument can be either:
31+
32+
- An object subtyping `AbstractADType`, defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl) and supported by [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl).
33+
- A `Symbol` like `:finite` (and variants thereof) or `:forward`, which fall back on `ADTypes.AutoFiniteDiff` and `ADTypes.AutoForwardDiff` respectively.
34+
- A `Bool`, namely `true`, which falls back on `ADTypes.AutoForwardDiff`.
35+
36+
When the positional argument `chunk` is passed, it is used to configure chunk size in `ADTypes.AutoForwardDiff`, but _only_ if `autodiff in (:forward, true)`.
37+
Indeed, if `autodiff isa ADTypes.AutoForwardDiff`, we assume that the user already selected the appropriate chunk size and so `chunk` is ignored.
38+
2739
## Examples
2840
#### Optimization
2941
Say we want to minimize the Hosaki test function

src/NLSolversBase.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ __precompile__(true)
22

33
module NLSolversBase
44

5-
using FiniteDiff, ForwardDiff, DiffResults
5+
using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff
6+
import DifferentiationInterface as DI
7+
using FiniteDiff: FiniteDiff
8+
using ForwardDiff: ForwardDiff
69
import Distributed: clear!
710
export AbstractObjective,
811
NonDifferentiable,
@@ -54,9 +57,24 @@ function finitediff_fdtype(autodiff)
5457
fdtype
5558
end
5659

60+
# No chunk given (`nothing`): report `nothing` as the chunk size as well.
function forwarddiff_chunksize(::Nothing)
    return nothing
end
61+
# Extract the chunk size `N` carried as the type parameter of a `ForwardDiff.Chunk`.
function forwarddiff_chunksize(::ForwardDiff.Chunk{N}) where {N}
    return N
end
62+
5763
# Return `true` when `autodiff` is one of the symbols that select finite differences.
# The membership operator `∈` is required: without it the definition does not parse.
is_finitediff(autodiff) = autodiff ∈ (:central, :finite, :finiteforward, :finitecomplex)
5864
# Return `true` when `autodiff` selects ForwardDiff: `:forward`, `:forwarddiff` or `true`.
# The membership operator `∈` is required: without it the definition does not parse.
is_forwarddiff(autodiff) = autodiff ∈ (:forward, :forwarddiff, true)
5965

66+
# An `AbstractADType` backend is returned untouched; `chunk` is accepted but ignored.
function get_adtype(autodiff::AbstractADType, chunk=nothing)
    return autodiff
end
67+
68+
# Translate a legacy `autodiff` flag (a `Symbol` or `Bool`) into an ADTypes backend.
#
# - finite-difference flags (see `is_finitediff`) give an `AutoFiniteDiff` whose `fdtype`
#   is an instance of the type returned by `finitediff_fdtype(autodiff)`;
# - forward-mode flags (see `is_forwarddiff`) give an `AutoForwardDiff` whose `chunksize`
#   is derived from `chunk` through `forwarddiff_chunksize`;
# - any other value raises an error.
function get_adtype(autodiff::Union{Symbol,Bool}, chunk=nothing)
    if is_finitediff(autodiff)
        fdtype = finitediff_fdtype(autodiff)()
        return AutoFiniteDiff(; fdtype=fdtype)
    end
    if is_forwarddiff(autodiff)
        chunksize = forwarddiff_chunksize(chunk)
        return AutoForwardDiff(; chunksize=chunksize)
    end
    return error("The autodiff value $autodiff is not supported. Use :finite or :forward.")
end
77+
6078
# Build a fresh container shaped like `x`, with elements converted to `Tf`
# (default: `eltype(x)`), and fill every entry with `Tf(NaN)`. `x` is left untouched.
function x_of_nans(x, Tf=eltype(x))
    converted = Tf.(x)
    return fill!(converted, Tf(NaN))
end
6179

6280
include("objective_types/inplace_factory.jl")

src/objective_types/constraints.jl

Lines changed: 39 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -139,27 +139,13 @@ function OnceDifferentiableConstraints(c!, lx::AbstractVector, ux::AbstractVecto
139139
xcache = zeros(T, sizex)
140140
ccache = zeros(T, sizec)
141141

142-
if is_finitediff(autodiff)
143-
ccache2 = similar(ccache)
144-
fdtype = finitediff_fdtype(autodiff)
145-
jacobian_cache = FiniteDiff.JacobianCache(xcache, ccache,ccache2,fdtype)
146-
function jfinite!(J, x)
147-
FiniteDiff.finite_difference_jacobian!(J, c!, x, jacobian_cache)
148-
J
149-
end
150-
return OnceDifferentiableConstraints(c!, jfinite!, bounds)
151-
elseif is_forwarddiff(autodiff)
152-
jac_cfg = ForwardDiff.JacobianConfig(c!, ccache, xcache, chunk)
153-
ForwardDiff.checktag(jac_cfg, c!, xcache)
154-
155-
function jforward!(J, x)
156-
ForwardDiff.jacobian!(J, c!, ccache, x, jac_cfg, Val{false}())
157-
J
158-
end
159-
return OnceDifferentiableConstraints(c!, jforward!, bounds)
160-
else
161-
error("The autodiff value $autodiff is not support. Use :finite or :forward.")
142+
backend = get_adtype(autodiff, chunk)
143+
jac_prep = DI.prepare_jacobian(c!, ccache, backend, xcache)
144+
# In-place constraint Jacobian: write the Jacobian of `c!` at `x` into `jac`,
# reusing the preparation object `jac_prep` and the output buffer `ccache`.
function j!(jac, x)
    DI.jacobian!(c!, ccache, jac, jac_prep, backend, x)
    return jac
end
148+
return OnceDifferentiableConstraints(c!, j!, bounds)
163149
end
164150

165151

@@ -179,153 +165,55 @@ function TwiceDifferentiableConstraints(c!, lx::AbstractVector, ux::AbstractVect
179165
lc::AbstractVector, uc::AbstractVector,
180166
autodiff::Symbol = :central,
181167
chunk::ForwardDiff.Chunk = checked_chunk(lx))
182-
if is_finitediff(autodiff)
183-
fdtype = finitediff_fdtype(autodiff)
184-
return twicediff_constraints_finite(c!,lx,ux,lc,uc,fdtype,nothing)
185-
elseif is_forwarddiff(autodiff)
186-
return twicediff_constraints_forward(c!,lx,ux,lc,uc,chunk,nothing)
187-
else
188-
error("The autodiff value $autodiff is not support. Use :finite or :forward.")
189-
end
190-
end
191-
192-
function TwiceDifferentiableConstraints(c!, con_jac!,lx::AbstractVector, ux::AbstractVector,
193-
lc::AbstractVector, uc::AbstractVector,
194-
autodiff::Symbol = :central,
195-
chunk::ForwardDiff.Chunk = checked_chunk(lx))
196-
if is_finitediff(autodiff)
197-
fdtype = finitediff_fdtype(autodiff)
198-
return twicediff_constraints_finite(c!,lx,ux,lc,uc,fdtype,con_jac!)
199-
elseif is_forwarddiff(autodiff)
200-
return twicediff_constraints_forward(c!,lx,ux,lc,uc,chunk,con_jac!)
201-
else
202-
error("The autodiff value $autodiff is not support. Use :finite or :forward.")
203-
end
204-
end
205-
206-
207-
208-
function TwiceDifferentiableConstraints(lx::AbstractArray, ux::AbstractArray)
209-
bounds = ConstraintBounds(lx, ux, [], [])
210-
TwiceDifferentiableConstraints(bounds)
211-
end
212-
213-
214-
function twicediff_constraints_forward(c!, lx, ux, lc, uc,chunk,con_jac! = nothing)
215168
bounds = ConstraintBounds(lx, ux, lc, uc)
216169
T = eltype(bounds)
217170
nc = length(lc)
218171
nx = length(lx)
172+
x_example = zeros(T, nx)
173+
λ_example = zeros(T, nc)
219174
ccache = zeros(T, nc)
220-
xcache = zeros(T, nx)
221-
cache_check = Ref{DataType}(Missing) #the datatype Missing, not the singleton
222-
ref_f= Ref{Any}() #cache for intermediate jacobian used in the hessian
223-
cxxcache = zeros(T, nx * nc, nx) #output cache for hessian
224-
h = reshape(cxxcache, (nc, nx, nx)) #reshaped output
225-
hi = [@view h[i,:,:] for i in 1:nc]
226-
#ref_f caches the closure function with its caches. other aproaches include using a Dict, but the
227-
#cost of switching happens just once per optimize call.
228-
229-
if isnothing(con_jac!) #if the jacobian is not provided, generate one
230-
jac_cfg = ForwardDiff.JacobianConfig(c!, ccache, xcache, chunk)
231-
ForwardDiff.checktag(jac_cfg, c!, xcache)
232-
233-
jac! = (J, x) -> begin
234-
ForwardDiff.jacobian!(J, c!, ccache, x, jac_cfg, Val{false}())
235-
J
236-
end
175+
176+
# Scalar function `x -> sum(λ[i] * c(x)[i])` whose Hessian with respect to `x`
# is the λ-weighted sum of the constraint Hessians.
#
# `_x` is the point at which the constraints are evaluated (it may carry AD number
# types), `_λ` holds the per-constraint weights.
function sum_constraints(_x, _λ)
    # TODO: get rid of this allocation with DI.Cache
    # The buffer eltype is promoted with `eltype(_x)` so `c!` can store whatever
    # number type the AD backend pushes through `_x`.
    ccache_righttype = zeros(promote_type(T, eltype(_x)), nc)
    c!(ccache_righttype, _x)
    # Read the values `c!` just wrote. Summing the outer `ccache` instead would give
    # a result independent of `_x`, so its derivatives would be lost.
    return sum(_λ[i] * ccache_righttype[i] for i in eachindex(_λ, ccache_righttype))
end
237182

238-
con_jac_cached = x -> begin
239-
exists_cache = (cache_check[] == eltype(x))
240-
if exists_cache
241-
f = ref_f[]
242-
return f(x)
243-
else
244-
jcache = zeros(eltype(x), nc)
245-
out_cache = zeros(eltype(x), nc, nx)
246-
cfg_cache = ForwardDiff.JacobianConfig(c!,jcache,x)
247-
f = z->ForwardDiff.jacobian!(out_cache, c!, jcache, z,cfg_cache,Val{false}())
248-
ref_f[] = f
249-
cache_check[]= eltype(x)
250-
return f(x)
251-
end
252-
end
183+
backend = get_adtype(autodiff, chunk)
253184

254-
else
255-
jac! = (J,x) -> con_jac!(J,x)
256-
257-
#here, the cache should also include a JacobianConfig
258-
con_jac_cached = x -> begin
259-
exists_cache = (cache_check[] == eltype(x))
260-
if exists_cache
261-
f = ref_f[]
262-
return f(x)
263-
else
264-
out_cache = zeros(eltype(x), nc, nx)
265-
f = z->jac!(out_cache,x)
266-
ref_f[] = f
267-
cache_check[]= eltype(x)
268-
return f(x)
269-
end
270-
end
185+
186+
jac_prep = DI.prepare_jacobian(c!, ccache, backend, x_example)
187+
# In-place constraint Jacobian: write the Jacobian of `c!` at `x` into `jac`,
# reusing the preparation object `jac_prep` and the output buffer `ccache`.
function con_jac!(jac, x)
    DI.jacobian!(c!, ccache, jac, jac_prep, backend, x)
    return jac
end
272-
273-
hess_config_cache = ForwardDiff.JacobianConfig(typeof(con_jac_cached),lx)
274-
function con_hess!(hess, x, λ)
275-
ForwardDiff.jacobian!(cxxcache, con_jac_cached, x,hess_config_cache,Val{false}())
276-
for i = 1:nc #hot hessian loop
277-
hess+=λ[i].*hi[i]
278-
end
279-
return hess
191+
192+
hess_prep = DI.prepare_hessian(sum_constraints, backend, x_example, DI.Constant(λ_example))
193+
# In-place Hessian of `sum_constraints` with respect to `x`, written into `hess`.
# The weights `λ` are passed as a `DI.Constant` context, i.e. they are not differentiated.
function con_hess!(hess, x, λ)
    weights = DI.Constant(λ)
    DI.hessian!(sum_constraints, hess, hess_prep, backend, x, weights)
    return hess
end
281197

282-
return TwiceDifferentiableConstraints(c!, jac!, con_hess!, bounds)
198+
return TwiceDifferentiableConstraints(c!, con_jac!, con_hess!, bounds)
283199
end
284200

285-
286-
function twicediff_constraints_finite(c!,lx,ux,lc,uc,fdtype,con_jac! = nothing)
287-
bounds = ConstraintBounds(lx, ux, lc, uc)
288-
T = eltype(bounds)
289-
nx = length(lx)
290-
nc = length(lc)
291-
xcache = zeros(T, nx)
292-
ccache = zeros(T, nc)
201+
function TwiceDifferentiableConstraints(c!, con_jac!,lx::AbstractVector, ux::AbstractVector,
202+
lc::AbstractVector, uc::AbstractVector,
203+
autodiff::Symbol = :central,
204+
chunk::ForwardDiff.Chunk = checked_chunk(lx))
205+
# TODO: is con_jac! still useful? we ignore it here
293206

294-
if isnothing(con_jac!)
295-
jac_ccache = similar(ccache)
296-
jacobian_cache = FiniteDiff.JacobianCache(xcache, ccache,jac_ccache,fdtype)
297-
function jac!(J, x)
298-
FiniteDiff.finite_difference_jacobian!(J, c!, x, jacobian_cache)
299-
J
300-
end
301-
else
302-
jac! = (J,x) -> con_jac!(J,x)
303-
end
304-
cxxcache = zeros(T,nc*nx,nx) # to create cached jacobian
305-
h = reshape(cxxcache, (nc, nx, nx)) #reshaped output
306-
hi = [@view h[i,:,:] for i in 1:nc]
307-
308-
function jac_vec!(J,x) #to evaluate the jacobian of a jacobian, FiniteDiff needs a vector version of that
309-
j_mat = reshape(J,nc,nx)
310-
return jac!(j_mat,x)
311-
return J
312-
end
313-
hess_xcache =similar(xcache)
314-
hess_cxcache =zeros(T,nc*nx) #output of jacobian, as a vector
315-
hess_cxxcache =similar(hess_cxcache)
316-
hess_config_cache = FiniteDiff.JacobianCache(hess_xcache,hess_cxcache,hess_cxxcache,fdtype)
317-
function con_hess!(hess, x, λ)
318-
FiniteDiff.finite_difference_jacobian!(cxxcache, jac_vec!, x,hess_config_cache)
319-
for i = 1:nc
320-
hi = @view h[i,:,:]
321-
hess+=λ[i].*hi
322-
end
323-
return hess
324-
end
325-
return TwiceDifferentiableConstraints(c!, jac!, con_hess!, bounds)
207+
return TwiceDifferentiableConstraints(c!, lx, ux, lc, uc, autodiff, chunk)
326208
end
327209

328210

211+
212+
# Build constraints from variable bounds `lx`/`ux` only: the two constraint-bound
# arguments of `ConstraintBounds` are passed as empty arrays.
function TwiceDifferentiableConstraints(lx::AbstractArray, ux::AbstractArray)
    box_only = ConstraintBounds(lx, ux, [], [])
    return TwiceDifferentiableConstraints(box_only)
end
216+
329217
function TwiceDifferentiableConstraints(bounds::ConstraintBounds)
330218
c! = (x, c)->nothing
331219
J! = (x, J)->nothing

0 commit comments

Comments
 (0)