|
1 | 1 | module SimpleNonlinearSolve
|
2 | 2 |
|
3 |
| -using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, |
4 |
| - AutoPolyesterForwardDiff |
5 | 3 | using CommonSolve: CommonSolve, solve
|
6 | 4 | using PrecompileTools: @compile_workload, @setup_workload
|
7 | 5 | using Reexport: @reexport
|
8 | 6 | @reexport using SciMLBase # I don't like this but needed to avoid a breaking change
|
9 | 7 | using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode
|
10 | 8 |
|
| 9 | +# AD Dependencies |
| 10 | +using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, |
| 11 | + AutoPolyesterForwardDiff |
| 12 | +using DifferentiationInterface: DifferentiationInterface |
| 13 | +# TODO: move these to extensions in a breaking change. These are not even used in the |
| 14 | +# package, but are used to trigger the extension loading in DI.jl |
| 15 | +using FiniteDiff: FiniteDiff |
| 16 | +using ForwardDiff: ForwardDiff |
| 17 | + |
11 | 18 | using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
|
12 | 19 | using NonlinearSolveBase: ImmutableNonlinearProblem
|
13 | 20 |
|
| 21 | +const DI = DifferentiationInterface |
| 22 | + |
14 | 23 | abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
|
15 | 24 |
|
16 | 25 | is_extension_loaded(::Val) = false
|
|
51 | 60 | function solve_adjoint_internal end
|
52 | 61 |
|
53 | 62 | @setup_workload begin
|
54 |
| - @compile_workload begin end |
| 63 | + for T in (Float32, Float64) |
| 64 | + prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) |
| 65 | + prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2)) |
| 66 | + prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2)) |
| 67 | + |
| 68 | + algs = [] |
| 69 | + algs_no_iip = [] |
| 70 | + |
| 71 | + @compile_workload begin |
| 72 | + for alg in algs, prob in (prob_scalar, prob_iip, prob_oop) |
| 73 | + CommonSolve.solve(prob, alg) |
| 74 | + end |
| 75 | + for alg in algs_no_iip |
| 76 | + CommonSolve.solve(prob_scalar, alg) |
| 77 | + end |
| 78 | + end |
| 79 | + end |
55 | 80 | end
|
56 | 81 |
|
57 | 82 | export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
|
|
0 commit comments