Skip to content

Commit aa4bdc1

Browse files
committed
add weakdeps, fix AD extensions
1 parent ac5c542 commit aa4bdc1

File tree

4 files changed

+51
-43
lines changed

4 files changed

+51
-43
lines changed

lib/NonlinearSolveBase/Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,27 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3232
[weakdeps]
3333
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3434
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
35+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3536
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3637
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
3738
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
39+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3840
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3941
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
42+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4043

4144
[extensions]
4245
NonlinearSolveBaseBandedMatricesExt = "BandedMatrices"
4346
NonlinearSolveBaseChainRulesCoreExt = "ChainRulesCore"
47+
NonlinearSolveBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
4448
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
4549
NonlinearSolveBaseLineSearchExt = "LineSearch"
4650
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
51+
NonlinearSolveBaseReverseDiffExt = "ReverseDiff"
4752
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
4853
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
54+
NonlinearSolveBaseTrackerExt = "Tracker"
55+
4956

5057
[compat]
5158
ADTypes = "1.9"
@@ -82,6 +89,7 @@ SparseMatrixColorings = "0.4.5"
8289
StaticArraysCore = "1.4"
8390
SymbolicIndexingInterface = "0.3.43"
8491
Test = "1.10"
92+
Tracker = "0.2.35"
8593
TimerOutputs = "0.5.23"
8694
julia = "1.10"
8795

lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module NonlinearSolveBaseEnzymeExt
22

33
@static if isempty(VERSION.prerelease)
44
using NonlinearSolveBase
5-
import SciMLBase: value
5+
import SciMLBase: SciMLBase, value
66
using Enzyme
77
import Enzyme: Const
88
using ChainRulesCore

lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,92 @@
11
module NonlinearSolveBaseReverseDiffExt
22

3-
using NonlinearSolveBase
4-
import SciMLBase: value
3+
using NonlinearSolveBase
4+
import SciMLBase: SciMLBase, value
55
import ReverseDiff
66
import ArrayInterface
77

88
# `ReverseDiff.TrackedArray`
9-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
9+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
1010
sensealg::Union{
1111
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
1212
Nothing}, u0::ReverseDiff.TrackedArray,
1313
p::ReverseDiff.TrackedArray, args...; kwargs...)
14-
ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
14+
ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
1515
end
1616

17-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
17+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
1818
sensealg::Union{
1919
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
2020
Nothing}, u0, p::ReverseDiff.TrackedArray,
2121
args...; kwargs...)
22-
ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
22+
ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
2323
end
2424

25-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
25+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
2626
sensealg::Union{
2727
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
2828
Nothing}, u0::ReverseDiff.TrackedArray, p,
2929
args...; kwargs...)
30-
ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
30+
ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
3131
end
3232

3333
# `AbstractArray{<:ReverseDiff.TrackedReal}`
34-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
34+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
3535
sensealg::Union{
3636
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
3737
Nothing},
3838
u0::AbstractArray{<:ReverseDiff.TrackedReal},
3939
p::AbstractArray{<:ReverseDiff.TrackedReal}, args...;
4040
kwargs...)
41-
SciMLBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0),
41+
NonlinearSolveBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0),
4242
ArrayInterface.aos_to_soa(p), args...;
4343
kwargs...)
4444
end
4545

46-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
46+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
4747
sensealg::Union{
4848
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
4949
Nothing}, u0,
5050
p::AbstractArray{<:ReverseDiff.TrackedReal},
5151
args...; kwargs...)
52-
SciMLBase.solve_up(
52+
NonlinearSolveBase.solve_up(
5353
prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...)
5454
end
5555

56-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
56+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
5757
sensealg::Union{
5858
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
5959
Nothing}, u0::ReverseDiff.TrackedArray,
6060
p::AbstractArray{<:ReverseDiff.TrackedReal},
6161
args...; kwargs...)
62-
SciMLBase.solve_up(
62+
NonlinearSolveBase.solve_up(
6363
prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...)
6464
end
6565

66-
function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem,
67-
sensealg::Union{
68-
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
69-
Nothing},
70-
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
71-
args...; kwargs...)
72-
SciMLBase.solve_up(
73-
prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...)
74-
end
66+
# function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem,
67+
# sensealg::Union{
68+
# SciMLBase.AbstractOverloadingSensitivityAlgorithm,
69+
# Nothing},
70+
# u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
71+
# args...; kwargs...)
72+
# NonlinearSolveBase.solve_up(
73+
# prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...)
74+
# end
7575

76-
function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem,
77-
sensealg::Union{
78-
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
79-
Nothing},
80-
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray,
81-
args...; kwargs...)
82-
SciMLBase.solve_up(
83-
prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...)
84-
end
76+
# function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem,
77+
# sensealg::Union{
78+
# SciMLBase.AbstractOverloadingSensitivityAlgorithm,
79+
# Nothing},
80+
# u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray,
81+
# args...; kwargs...)
82+
# NonlinearSolveBase.solve_up(
83+
# prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...)
84+
# end
8585

8686
# Required becase ReverseDiff.@grad function SciMLBase.solve_up is not supported!
87-
import SciMLBase: solve_up
87+
import NonlinearSolveBase: solve_up
8888
ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...)
89-
out = SciMLBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0),
89+
out = NonlinearSolveBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0),
9090
ReverseDiff.value(p),
9191
SciMLBase.ReverseDiffOriginator(), args...; kwargs...)
9292
function actual_adjoint(_args...)

lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,34 @@
11
module NonlinearSolveBaseTrackerExt
22

33
using NonlinearSolveBase
4-
import SciMLBase: value
4+
import SciMLBase: SciMLBase, value
55
import Tracker
66

7-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
7+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
88
sensealg::Union{
99
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
1010
Nothing}, u0::Tracker.TrackedArray,
1111
p::Tracker.TrackedArray, args...; kwargs...)
12-
Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
12+
Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
1313
end
1414

15-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
15+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
1616
sensealg::Union{
1717
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
1818
Nothing}, u0::Tracker.TrackedArray, p, args...;
1919
kwargs...)
20-
Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
20+
Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
2121
end
2222

23-
function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem,
23+
function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem,
2424
sensealg::Union{
2525
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
2626
Nothing}, u0, p::Tracker.TrackedArray, args...;
2727
kwargs...)
28-
Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
28+
Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
2929
end
3030

31-
Tracker.@grad function SciMLBase.solve_up(prob,
31+
Tracker.@grad function NonlinearSolveBase.solve_up(prob,
3232
sensealg::Union{Nothing,
3333
SciMLBase.AbstractOverloadingSensitivityAlgorithm
3434
},

0 commit comments

Comments
 (0)