|
1 | 1 | module NonlinearSolveBaseReverseDiffExt |
2 | 2 |
|
3 | | -using NonlinearSolveBase |
4 | | -import SciMLBase: value |
| 3 | +using NonlinearSolveBase |
| 4 | +import SciMLBase: SciMLBase, value |
5 | 5 | import ReverseDiff |
6 | 6 | import ArrayInterface |
7 | 7 |
|
8 | 8 | # `ReverseDiff.TrackedArray` |
9 | | -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, |
| 9 | +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, |
10 | 10 | sensealg::Union{ |
11 | 11 | SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
12 | 12 | Nothing}, u0::ReverseDiff.TrackedArray, |
13 | 13 | p::ReverseDiff.TrackedArray, args...; kwargs...) |
14 | | - ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) |
| 14 | + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) |
15 | 15 | end |
16 | 16 |
|
17 | | -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, |
| 17 | +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, |
18 | 18 | sensealg::Union{ |
19 | 19 | SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
20 | 20 | Nothing}, u0, p::ReverseDiff.TrackedArray, |
21 | 21 | args...; kwargs...) |
22 | | - ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) |
| 22 | + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) |
23 | 23 | end |
24 | 24 |
|
25 | | -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, |
| 25 | +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, |
26 | 26 | sensealg::Union{ |
27 | 27 | SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
28 | 28 | Nothing}, u0::ReverseDiff.TrackedArray, p, |
29 | 29 | args...; kwargs...) |
30 | | - ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) |
| 30 | + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) |
31 | 31 | end |
32 | 32 |
|
33 | 33 | # `AbstractArray{<:ReverseDiff.TrackedReal}` |
34 | | -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, |
| 34 | +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, |
35 | 35 | sensealg::Union{ |
36 | 36 | SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
37 | 37 | Nothing}, |
38 | 38 | u0::AbstractArray{<:ReverseDiff.TrackedReal}, |
39 | 39 | p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; |
40 | 40 | kwargs...) |
41 | | - SciMLBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), |
| 41 | + NonlinearSolveBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), |
42 | 42 | ArrayInterface.aos_to_soa(p), args...; |
43 | 43 | kwargs...) |
44 | 44 | end |
45 | 45 |
|
46 | | -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, |
| 46 | +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, |
47 | 47 | sensealg::Union{ |
48 | 48 | SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
49 | 49 | Nothing}, u0, |
50 | 50 | p::AbstractArray{<:ReverseDiff.TrackedReal}, |
51 | 51 | args...; kwargs...) |
52 | | - SciMLBase.solve_up( |
| 52 | + NonlinearSolveBase.solve_up( |
53 | 53 | prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...) |
54 | 54 | end |
55 | 55 |
|
56 | | -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, |
| 56 | +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, |
57 | 57 | sensealg::Union{ |
58 | 58 | SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
59 | 59 | Nothing}, u0::ReverseDiff.TrackedArray, |
60 | 60 | p::AbstractArray{<:ReverseDiff.TrackedReal}, |
61 | 61 | args...; kwargs...) |
62 | | - SciMLBase.solve_up( |
| 62 | + NonlinearSolveBase.solve_up( |
63 | 63 | prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...) |
64 | 64 | end |
65 | 65 |
|
66 | | -function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, |
67 | | - sensealg::Union{ |
68 | | - SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
69 | | - Nothing}, |
70 | | - u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, |
71 | | - args...; kwargs...) |
72 | | - SciMLBase.solve_up( |
73 | | - prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) |
74 | | -end |
| 66 | +# function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, |
| 67 | +# sensealg::Union{ |
| 68 | +# SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
| 69 | +# Nothing}, |
| 70 | +# u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, |
| 71 | +# args...; kwargs...) |
| 72 | +# NonlinearSolveBase.solve_up( |
| 73 | +# prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) |
| 74 | +# end |
75 | 75 |
|
76 | | -function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, |
77 | | - sensealg::Union{ |
78 | | - SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
79 | | - Nothing}, |
80 | | - u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, |
81 | | - args...; kwargs...) |
82 | | - SciMLBase.solve_up( |
83 | | - prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) |
84 | | -end |
| 76 | +# function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, |
| 77 | +# sensealg::Union{ |
| 78 | +# SciMLBase.AbstractOverloadingSensitivityAlgorithm, |
| 79 | +# Nothing}, |
| 80 | +# u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, |
| 81 | +# args...; kwargs...) |
| 82 | +# NonlinearSolveBase.solve_up( |
| 83 | +# prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) |
| 84 | +# end |
85 | 85 |
|
86 | 86 | # Required becase ReverseDiff.@grad function SciMLBase.solve_up is not supported! |
87 | | -import SciMLBase: solve_up |
| 87 | +import NonlinearSolveBase: solve_up |
88 | 88 | ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...) |
89 | | - out = SciMLBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0), |
| 89 | + out = NonlinearSolveBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0), |
90 | 90 | ReverseDiff.value(p), |
91 | 91 | SciMLBase.ReverseDiffOriginator(), args...; kwargs...) |
92 | 92 | function actual_adjoint(_args...) |
|
0 commit comments