Skip to content

Commit 8790d79

Browse files
committed
feat: add the AD workflows
1 parent 8153008 commit 8790d79

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
1010
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
11+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
12+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
13+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1114
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1215
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1316
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -32,6 +35,9 @@ BracketingNonlinearSolve = "1"
3235
ChainRulesCore = "1.24"
3336
CommonSolve = "0.2.4"
3437
DiffEqBase = "6.155"
38+
DifferentiationInterface = "0.5.17"
39+
FiniteDiff = "2.24.0"
40+
ForwardDiff = "0.10.36"
3541
NonlinearSolveBase = "1"
3642
PrecompileTools = "1.2"
3743
Reexport = "1.2"

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
module SimpleNonlinearSolve
22

3-
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
4-
AutoPolyesterForwardDiff
53
using CommonSolve: CommonSolve, solve
64
using PrecompileTools: @compile_workload, @setup_workload
75
using Reexport: @reexport
86
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
97
using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode
108

9+
# AD Dependencies
10+
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
11+
AutoPolyesterForwardDiff
12+
using DifferentiationInterface: DifferentiationInterface
13+
# TODO: move these to extensions in a breaking change. These are not even used in the
14+
# package, but are used to trigger the extension loading in DI.jl
15+
using FiniteDiff: FiniteDiff
16+
using ForwardDiff: ForwardDiff
17+
1118
using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
1219
using NonlinearSolveBase: ImmutableNonlinearProblem
1320

21+
const DI = DifferentiationInterface
22+
1423
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
1524

1625
is_extension_loaded(::Val) = false
@@ -51,7 +60,23 @@ end
5160
function solve_adjoint_internal end
5261

5362
@setup_workload begin
54-
@compile_workload begin end
63+
for T in (Float32, Float64)
64+
prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
65+
prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2))
66+
prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2))
67+
68+
algs = []
69+
algs_no_iip = []
70+
71+
@compile_workload begin
72+
for alg in algs, prob in (prob_scalar, prob_iip, prob_oop)
73+
CommonSolve.solve(prob, alg)
74+
end
75+
for alg in algs_no_iip
76+
CommonSolve.solve(prob_scalar, alg)
77+
end
78+
end
79+
end
5580
end
5681

5782
export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff

0 commit comments

Comments
 (0)