Commit 3af16d7: Use autodiff API based on ADTypes instead of symbols
Parent: 22be384

14 files changed: +55 -48 lines
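In user code, the change amounts to swapping the old symbols for ADTypes backend objects. A minimal before/after sketch, using the Rosenbrock objective that appears throughout the Optim docs:

```julia
using Optim
using ADTypes: AutoForwardDiff
using ForwardDiff  # the chosen backend's AD package must be loaded

f(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
x0 = zeros(2)

# Before this commit, the backend was named with a symbol:
#     optimize(f, x0, LBFGS(); autodiff = :forward)
# After it, the backend is an ADTypes object:
res = optimize(f, x0, LBFGS(); autodiff = AutoForwardDiff())
```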

Project.toml

Lines changed: 3 additions & 2 deletions

```diff
@@ -4,6 +4,7 @@ version = "1.14.0"
 
 
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -25,6 +26,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
 OptimMOIExt = "MathOptInterface"
 
 [compat]
+ADTypes = "1.11.0"
 Compat = "3.2.0, 3.3.0, 3.4.0, 3.5.0, 3.6.0, 4"
 EnumX = "1.0.4"
 FillArrays = "0.6.2, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
@@ -44,7 +46,6 @@ Test = "<0.0.1, 1.6"
 julia = "1.10"
 
 [extras]
-ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
 MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
@@ -59,4 +60,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Distributions", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "LineSearches", "NLSolversBase", "PositiveFactorizations", "ReverseDiff", "ADTypes"]
+test = ["Test", "Distributions", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "LineSearches", "NLSolversBase", "PositiveFactorizations", "ReverseDiff"]
```

docs/src/examples/ipnewton_basics.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -78,7 +78,7 @@ using Test #src
 @test Optim.converged(res) #src
 @test Optim.minimum(res) ≈ 0.25 #src
 
-# Like the rest of Optim, you can also use `autodiff=:forward` and just pass in
+# Like the rest of Optim, you can also use `autodiff=ADTypes.AutoForwardDiff()` and just pass in
 # `fun`.
 
 # If we only want to set lower bounds, use `ux = fill(Inf, 2)`
```
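For context, a sketch of what that usage looks like under the new API. The bounds and start point here are illustrative stand-ins, not necessarily the values defined earlier in the example file:

```julia
using Optim, ForwardDiff, ADTypes

fun(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2  # Rosenbrock, as in the example
x0 = [0.25, 0.25]                                     # must be strictly inside the box
lx = fill(-0.5, 2); ux = fill(0.5, 2)                 # illustrative box bounds

# Build the objective with forward-mode AD instead of hand-coded derivatives.
df = TwiceDifferentiable(fun, x0; autodiff = ADTypes.AutoForwardDiff())
dfc = TwiceDifferentiableConstraints(lx, ux)
res = optimize(df, dfc, x0, IPNewton())
```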

docs/src/examples/maxlikenlm.jl

Lines changed: 3 additions & 2 deletions

```diff
@@ -22,6 +22,7 @@
 using Optim, NLSolversBase
 using LinearAlgebra: diag
 using ForwardDiff
+using ADTypes: AutoForwardDiff
 
 #md # !!! tip
 #md #     Add Optim with the following command at the Julia command prompt:
@@ -152,7 +153,7 @@ end
 func = TwiceDifferentiable(
     vars -> Log_Likelihood(x, y, vars[1:nvar], vars[nvar+1]),
     ones(nvar + 1);
-    autodiff = :forward,
+    autodiff = AutoForwardDiff(),
 );
 
 # The above statement accepts 4 inputs: the x matrix, the dependent
@@ -163,7 +164,7 @@ func = TwiceDifferentiable(
 # the error variance.
 #
 # The `ones(nvar+1)` are the starting values for the parameters and
-# the `autodiff=:forward` command performs forward mode automatic
+# the `autodiff=ADTypes.AutoForwardDiff()` command performs forward mode automatic
 # differentiation.
 #
 # The actual optimization of the likelihood function is accomplished
```
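A condensed, self-contained sketch of the updated pattern; the objective here is a toy stand-in, not the example's `Log_Likelihood`:

```julia
using Optim, NLSolversBase, ForwardDiff
using ADTypes: AutoForwardDiff

nvar = 2
# Stand-in for the negative log-likelihood of the example.
toy_nll(vars) = sum(abs2, vars)

func = TwiceDifferentiable(toy_nll, ones(nvar + 1); autodiff = AutoForwardDiff())
opt = optimize(func, ones(nvar + 1))  # Newton is the fallback for a TwiceDifferentiable
```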

docs/src/user/gradientsandhessians.md

Lines changed: 7 additions & 7 deletions

````diff
@@ -16,10 +16,10 @@ Automatic differentiation techniques are a middle ground between finite differences
 
 Reverse-mode automatic differentiation can be seen as an automatic implementation of the adjoint method mentioned above, and requires a runtime comparable to only one evaluation of ``f``. It is however considerably more complex to implement, requiring to record the execution of the program to then run it backwards, and incurs a larger overhead.
 
-Forward-mode automatic differentiation is supported through the [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) package by providing the `autodiff=:forward` keyword to `optimize`.
-More generic automatic differentiation is supported thanks to [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), by setting `autodiff` to any compatible backend object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
-For instance, the user can choose `autodiff=AutoReverseDiff()`, `autodiff=AutoEnzyme()`, `autodiff=AutoMooncake()` or `autodiff=AutoZygote()` for a reverse-mode gradient computation, which is generally faster than forward mode on large inputs.
-Each of these choices requires loading the corresponding package beforehand.
+Generic automatic differentiation is supported thanks to [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), by setting `autodiff` to any compatible backend object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
+For instance, forward-mode automatic differentiation is supported through the [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) package by providing the `autodiff=ADTypes.AutoForwardDiff()` keyword to `optimize`.
+Additionally, the user can choose `autodiff=AutoReverseDiff()`, `autodiff=AutoEnzyme()`, `autodiff=AutoMooncake()` or `autodiff=AutoZygote()` for a reverse-mode gradient computation, which is generally faster than forward mode on large inputs.
+Each of these choices requires loading the `ADTypes` package and the corresponding automatic differentiation package (e.g., `ForwardDiff` or `ReverseDiff`) beforehand.
 
 ## Example
 
@@ -66,14 +66,14 @@ julia> Optim.minimizer(optimize(f, initial_x, BFGS()))
 ```
 Still looks good. Returning to automatic differentiation, let us try both solvers using this
 method. We enable [forward mode](https://github.com/JuliaDiff/ForwardDiff.jl) automatic
-differentiation by using the `autodiff = :forward` keyword.
+differentiation by using the `autodiff = AutoForwardDiff()` keyword.
 ```jlcon
-julia> Optim.minimizer(optimize(f, initial_x, BFGS(); autodiff = :forward))
+julia> Optim.minimizer(optimize(f, initial_x, BFGS(); autodiff = AutoForwardDiff()))
 2-element Array{Float64,1}:
 1.0
 1.0
 
-julia> Optim.minimizer(optimize(f, initial_x, Newton(); autodiff = :forward))
+julia> Optim.minimizer(optimize(f, initial_x, Newton(); autodiff = AutoForwardDiff()))
 2-element Array{Float64,1}:
 1.0
 1.0
````
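To illustrate the backend-object idiom this page now documents, a hedged sketch of a reverse-mode run, assuming the same Rosenbrock `f` and starting point as the example above and that ReverseDiff is installed:

```julia
using Optim, ReverseDiff
using ADTypes: AutoReverseDiff

f(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
initial_x = zeros(2)

# Reverse mode: one gradient costs roughly one extra evaluation of f,
# independent of the input dimension.
Optim.minimizer(optimize(f, initial_x, BFGS(); autodiff = AutoReverseDiff()))
```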

docs/src/user/minimization.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -26,9 +26,9 @@ If we pass `f` alone, Optim will construct an approximate gradient for us using
 ```jl
 optimize(f, x0, LBFGS())
 ```
-For better performance and greater precision, you can pass your own gradient function. If your objective is written in all Julia code with no special calls to external (that is non-Julia) libraries, you can also use automatic differentiation, by using the `autodiff` keyword and setting it to `:forward`:
+For better performance and greater precision, you can pass your own gradient function. If your objective is written in all Julia code with no special calls to external (that is non-Julia) libraries, you can also use automatic differentiation, by using the `autodiff` keyword and setting it to `AutoForwardDiff()`:
 ```julia
-optimize(f, x0, LBFGS(); autodiff = :forward)
+optimize(f, x0, LBFGS(); autodiff = AutoForwardDiff())
 ```
 
 For the Rosenbrock example, the analytical gradient can be shown to be:
````
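That analytical gradient is the standard Rosenbrock derivative; for reference, a sketch in the in-place form Optim expects:

```julia
# In-place gradient of f(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2;
# `g` holds the gradient on exit.
function g!(g, x)
    g[1] = -2.0 * (1.0 - x[1]) - 400.0 * (x[2] - x[1]^2) * x[1]
    g[2] = 200.0 * (x[2] - x[1]^2)
    return g
end
```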

ext/OptimMOIExt.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -335,7 +335,7 @@ function MOI.optimize!(model::Optimizer{T}) where {T}
             inplace = true,
         )
     else
-        d = Optim.promote_objtype(method, initial_x, :finite, true, f, g!, h!)
+        d = Optim.promote_objtype(method, initial_x, Optim.DEFAULT_AD_TYPE, true, f, g!, h!)
         options = Optim.Options(; Optim.default_options(method)..., options...)
         if nl_constrained || has_bounds
             if nl_constrained
```

src/Optim.jl

Lines changed: 2 additions & 0 deletions

```diff
@@ -47,6 +47,8 @@ import NLSolversBase:
 # var for NelderMead
 import StatsBase: var
 
+import ADTypes
+
 import LinearAlgebra
 import LinearAlgebra:
     Diagonal,
```

src/multivariate/optimize/interface.jl

Lines changed: 14 additions & 11 deletions

```diff
@@ -4,6 +4,9 @@ fallback_method(f) = NelderMead()
 fallback_method(f, g!) = LBFGS()
 fallback_method(f, g!, h!) = Newton()
 
+# By default, use central finite difference method
+const DEFAULT_AD_TYPE = ADTypes.AutoFiniteDiff(; fdtype = Val(:central))
+
 function fallback_method(f::InplaceObjective)
     if !(f.fdf isa Nothing)
         if !(f.hv isa Nothing)
@@ -137,7 +140,7 @@ function optimize(
     f,
     initial_x::AbstractArray;
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     method = fallback_method(f)
     d = promote_objtype(method, initial_x, autodiff, inplace, f)
@@ -149,7 +152,7 @@ function optimize(
     f,
     g,
     initial_x::AbstractArray;
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
     inplace = true,
 )
 
@@ -166,7 +169,7 @@ function optimize(
     h,
     initial_x::AbstractArray;
     inplace = true,
-    autodiff = :finite
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     method = fallback_method(f, g, h)
     d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
@@ -189,7 +192,7 @@ function optimize(
     initial_x::AbstractArray,
     options::Options;
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     method = fallback_method(f)
     d = promote_objtype(method, initial_x, autodiff, inplace, f)
@@ -201,7 +204,7 @@ function optimize(
     initial_x::AbstractArray,
     options::Options;
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
 
     method = fallback_method(f, g)
@@ -215,7 +218,7 @@ function optimize(
     initial_x::AbstractArray{T},
     options::Options;
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 ) where {T}
     method = fallback_method(f, g, h)
     d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
@@ -230,7 +233,7 @@ function optimize(
     method::AbstractOptimizer,
     options::Options = Options(; default_options(method)...);
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     d = promote_objtype(method, initial_x, autodiff, inplace, f)
     optimize(d, initial_x, method, options)
@@ -242,7 +245,7 @@ function optimize(
     method::AbstractOptimizer,
     options::Options = Options(; default_options(method)...);
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
 
     d = promote_objtype(method, initial_x, autodiff, inplace, f)
@@ -255,7 +258,7 @@ function optimize(
     method::AbstractOptimizer,
     options::Options = Options(; default_options(method)...);
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
 
@@ -269,7 +272,7 @@ function optimize(
     method::AbstractOptimizer,
     options::Options = Options(; default_options(method)...);
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 
 )
     d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
@@ -283,7 +286,7 @@ function optimize(
     method::SecondOrderOptimizer,
     options::Options = Options(; default_options(method)...);
     inplace = true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 ) where {D<:Union{NonDifferentiable,OnceDifferentiable}}
     d = promote_objtype(method, initial_x, autodiff, inplace, d)
     optimize(d, initial_x, method, options)
```
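The default behaviour is therefore unchanged in spirit (central finite differences), just spelled as a backend object rather than `:finite`. A small sketch of the equivalence, with a trivial objective:

```julia
using Optim, FiniteDiff  # FiniteDiff backs the finite-difference backend
using ADTypes: AutoFiniteDiff

f(x) = sum(abs2, x)

# These two calls use the same backend: the first through Optim.DEFAULT_AD_TYPE,
# the second by passing the central-difference backend explicitly.
optimize(f, ones(2), LBFGS())
optimize(f, ones(2), LBFGS(); autodiff = AutoFiniteDiff(; fdtype = Val(:central)))
```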

src/multivariate/solvers/constrained/fminbox.jl

Lines changed: 8 additions & 8 deletions

```diff
@@ -278,7 +278,7 @@ function optimize(
     F::Fminbox = Fminbox(),
     options::Options = Options();
     inplace::Bool=true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     if f isa NonDifferentiable
         f = f.f
@@ -304,7 +304,7 @@ function optimize(
     optimize(od, l, u, initial_x, F, options)
 end
 
-function optimize(f, l::Number, u::Number, initial_x::AbstractArray; autodiff = :finite)
+function optimize(f, l::Number, u::Number, initial_x::AbstractArray; autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE)
     T = eltype(initial_x)
     optimize(
         OnceDifferentiable(f, initial_x, zero(T); autodiff),
@@ -324,7 +324,7 @@ optimize(
     mo::AbstractConstrainedOptimizer,
     opt::Options = Options();
     inplace::Bool=true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 ) = optimize(
     f,
     Fill(T(l), size(initial_x)...),
@@ -343,7 +343,7 @@ function optimize(
     mo::AbstractConstrainedOptimizer = Fminbox(),
     opt::Options = Options();
     inplace::Bool=true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     T = eltype(initial_x)
     optimize(f, T.(l), Fill(T(u), size(initial_x)...), initial_x, mo, opt; inplace, autodiff)
@@ -356,7 +356,7 @@ function optimize(
     mo::AbstractConstrainedOptimizer=Fminbox(),
     opt::Options = Options();
     inplace::Bool=true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     T = eltype(initial_x)
     optimize(f, Fill(T(l), size(initial_x)...), T.(u), initial_x, mo, opt; inplace, autodiff)
@@ -369,7 +369,7 @@ function optimize(
     initial_x::AbstractArray,
     opt::Options;
     inplace::Bool=true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
 
     T = eltype(initial_x)
@@ -393,7 +393,7 @@ function optimize(
     initial_x::AbstractArray,
     opt::Options;
     inplace::Bool=true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     T = eltype(initial_x)
     optimize(f, g, T.(l), Fill(T(u), size(initial_x)...), initial_x, opt; inplace, autodiff)
@@ -407,7 +407,7 @@ function optimize(
     initial_x::AbstractArray,
     opt::Options;
     inplace::Bool=true,
-    autodiff = :finite,
+    autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
 )
     T = eltype(initial_x)
     optimize(f, g, Fill(T(l), size(initial_x)...), T.(u), initial_x, opt, inplace, autodiff)
```
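A hedged sketch of a box-constrained call through these methods with the new keyword; the objective, bounds, and start point are illustrative:

```julia
using Optim, ForwardDiff
using ADTypes: AutoForwardDiff

f(x) = sum(abs2, x .- 0.5)
lower = fill(-1.0, 2)
upper = fill(1.0, 2)
x0 = fill(0.1, 2)  # strictly inside the box

# Fminbox wraps an inner first-order method; the gradient comes from the backend.
res = optimize(f, lower, upper, x0, Fminbox(LBFGS()); autodiff = AutoForwardDiff())
```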

test/general/objective_types.jl

Lines changed: 7 additions & 7 deletions

```diff
@@ -4,8 +4,8 @@
     @test_throws ErrorException OnceDifferentiable(x -> x, rand(10); autodiff = :wah)
 
     for T in (OnceDifferentiable, TwiceDifferentiable)
-        odad1 = T(x -> 5.0, rand(1); autodiff = :finite)
-        odad2 = T(x -> 5.0, rand(1); autodiff = :forward)
+        odad1 = T(x -> 5.0, rand(1); autodiff = AutoFiniteDiff(; fdtype = Val(:central)))
+        odad2 = T(x -> 5.0, rand(1); autodiff = AutoForwardDiff())
         odad3 = T(x -> 5.0, rand(1); autodiff = AutoReverseDiff())
         Optim.gradient!(odad1, rand(1))
         Optim.gradient!(odad2, rand(1))
@@ -17,8 +17,8 @@
 
     for a in (1.0, 5.0)
         xa = rand(1)
-        odad1 = OnceDifferentiable(x -> a * x[1], xa; autodiff = :finite)
-        odad2 = OnceDifferentiable(x -> a * x[1], xa; autodiff = :forward)
+        odad1 = OnceDifferentiable(x -> a * x[1], xa; autodiff = AutoFiniteDiff(; fdtype = Val(:central)))
+        odad2 = OnceDifferentiable(x -> a * x[1], xa; autodiff = AutoForwardDiff())
         odad3 = OnceDifferentiable(x -> a * x[1], xa; autodiff = AutoReverseDiff())
         Optim.gradient!(odad1, xa)
         Optim.gradient!(odad2, xa)
@@ -29,8 +29,8 @@
     end
     for a in (1.0, 5.0)
        xa = rand(1)
-        odad1 = OnceDifferentiable(x -> a * x[1]^2, xa; autodiff = :finite)
-        odad2 = OnceDifferentiable(x -> a * x[1]^2, xa; autodiff = :forward)
+        odad1 = OnceDifferentiable(x -> a * x[1]^2, xa; autodiff = AutoFiniteDiff(; fdtype = Val(:central)))
+        odad2 = OnceDifferentiable(x -> a * x[1]^2, xa; autodiff = AutoForwardDiff())
         odad3 = OnceDifferentiable(x -> a * x[1]^2, xa; autodiff = AutoReverseDiff())
         Optim.gradient!(odad1, xa)
         Optim.gradient!(odad2, xa)
@@ -40,7 +40,7 @@
         @test Optim.gradient(odad3) == 2.0 * a * xa
     end
     for dtype in (OnceDifferentiable, TwiceDifferentiable)
-        for autodiff in (:finite, :forward, AutoReverseDiff())
+        for autodiff in (AutoFiniteDiff(; fdtype = Val(:central)), AutoForwardDiff(), AutoReverseDiff())
            differentiable = dtype(x -> sum(x), rand(2); autodiff = autodiff)
            Optim.value(differentiable)
            Optim.value!(differentiable, rand(2))
```
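The pattern these tests exercise, as a standalone sketch; it assumes FiniteDiff, ForwardDiff, and ReverseDiff are installed so their DifferentiationInterface extensions load:

```julia
using Optim, NLSolversBase, FiniteDiff, ForwardDiff, ReverseDiff
using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff

backends = (AutoFiniteDiff(; fdtype = Val(:central)), AutoForwardDiff(), AutoReverseDiff())
for backend in backends
    # Same objective under each backend; same calls as in the test above.
    od = OnceDifferentiable(x -> sum(abs2, x), rand(2); autodiff = backend)
    Optim.gradient!(od, rand(2))
end
```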
