Skip to content

Commit 4d9c30e

Browse files
committed
fix: simplenonlinearsolve in cuda kernels
1 parent ecdbc78 commit 4d9c30e

File tree

12 files changed

+219
-30
lines changed

12 files changed

+219
-30
lines changed

.buildkite/pipeline.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,34 @@ steps:
2727
# Don't run Buildkite if the commit message includes the text [skip tests]
2828
if: build.message !~ /\[skip tests\]/
2929

30+
- label: "Julia 1 (SimpleNonlinearSolve)"
31+
plugins:
32+
- JuliaCI/julia#v1:
33+
version: "1"
34+
- JuliaCI/julia-coverage#v1:
35+
codecov: true
36+
dirs:
37+
- src
38+
- ext
39+
command: |
40+
julia --color=yes --code-coverage=user --depwarn=yes --project=lib/SimpleNonlinearSolve -e '
41+
import Pkg;
42+
Pkg.Registry.update();
43+
# Install packages present in subdirectories
44+
dev_pks = Pkg.PackageSpec[];
45+
for path in ("lib/NonlinearSolveBase", "lib/BracketingNonlinearSolve")
46+
push!(dev_pks, Pkg.PackageSpec(; path))
47+
end
48+
Pkg.develop(dev_pks);
49+
Pkg.instantiate();
50+
Pkg.test(; coverage=true)'
51+
agents:
52+
queue: "juliagpu"
53+
cuda: "*"
54+
timeout_in_minutes: 60
55+
# Don't run Buildkite if the commit message includes the text [skip tests]
56+
if: build.message !~ /\[skip tests\]/
57+
3058
env:
3159
GROUP: CUDA
3260
JULIA_PKG_SERVER: "" # it often struggles with our large artifacts

lib/BracketingNonlinearSolve/src/bisection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Bisection,
2828
fl, fr = f(left), f(right)
2929

3030
abstol = NonlinearSolveBase.get_tolerance(
31-
abstol, promote_type(eltype(left), eltype(right)))
31+
left, abstol, promote_type(eltype(left), eltype(right)))
3232

3333
if iszero(fl)
3434
return SciMLBase.build_solution(

lib/BracketingNonlinearSolve/src/brent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
1515
ϵ = eps(convert(typeof(fl), 1))
1616

1717
abstol = NonlinearSolveBase.get_tolerance(
18-
abstol, promote_type(eltype(left), eltype(right)))
18+
left, abstol, promote_type(eltype(left), eltype(right)))
1919

2020
if iszero(fl)
2121
return SciMLBase.build_solution(

lib/BracketingNonlinearSolve/src/falsi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
1515
fl, fr = f(left), f(right)
1616

1717
abstol = NonlinearSolveBase.get_tolerance(
18-
abstol, promote_type(eltype(left), eltype(right)))
18+
left, abstol, promote_type(eltype(left), eltype(right)))
1919

2020
if iszero(fl)
2121
return SciMLBase.build_solution(

lib/BracketingNonlinearSolve/src/itp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
6565
fl, fr = f(left), f(right)
6666

6767
abstol = NonlinearSolveBase.get_tolerance(
68-
abstol, promote_type(eltype(left), eltype(right)))
68+
left, abstol, promote_type(eltype(left), eltype(right)))
6969

7070
if iszero(fl)
7171
return SciMLBase.build_solution(

lib/BracketingNonlinearSolve/src/ridder.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
1414
fl, fr = f(left), f(right)
1515

1616
abstol = NonlinearSolveBase.get_tolerance(
17-
abstol, promote_type(eltype(left), eltype(right)))
17+
left, abstol, promote_type(eltype(left), eltype(right)))
1818

1919
if iszero(fl)
2020
return SciMLBase.build_solution(

lib/NonlinearSolveBase/src/termination_conditions.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ function SciMLBase.init(
3434
du, u, mode::AbstractNonlinearTerminationMode, saved_value_prototype...;
3535
abstol = nothing, reltol = nothing, kwargs...)
3636
T = promote_type(eltype(du), eltype(u))
37-
abstol = get_tolerance(abstol, T)
38-
reltol = get_tolerance(reltol, T)
37+
abstol = get_tolerance(u, abstol, T)
38+
reltol = get_tolerance(u, reltol, T)
3939
TT = typeof(abstol)
4040

4141
u_unaliased = mode isa AbstractSafeBestNonlinearTerminationMode ?
@@ -90,8 +90,8 @@ function SciMLBase.reinit!(
9090
cache.u = u_unaliased
9191
cache.retcode = ReturnCode.Default
9292

93-
cache.abstol = get_tolerance(abstol, T)
94-
cache.reltol = get_tolerance(reltol, T)
93+
cache.abstol = get_tolerance(u, abstol, T)
94+
cache.reltol = get_tolerance(u, reltol, T)
9595
cache.nsteps = 0
9696
TT = typeof(cache.abstol)
9797

@@ -274,8 +274,8 @@ end
274274
function init_termination_cache(::AbstractNonlinearProblem, abstol, reltol, du,
275275
u, tc::AbstractNonlinearTerminationMode, ::Val)
276276
T = promote_type(eltype(du), eltype(u))
277-
abstol = get_tolerance(abstol, T)
278-
reltol = get_tolerance(reltol, T)
277+
abstol = get_tolerance(u, abstol, T)
278+
reltol = get_tolerance(u, reltol, T)
279279
cache = SciMLBase.init(du, u, tc; abstol, reltol)
280280
return abstol, reltol, cache
281281
end

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,35 +37,62 @@ SimpleNonlinearSolveTrackerExt = "Tracker"
3737

3838
[compat]
3939
ADTypes = "1.2"
40+
Accessors = "0.1"
41+
AllocCheck = "0.1.1"
42+
Aqua = "0.8.7"
4043
ArrayInterface = "7.16"
4144
BracketingNonlinearSolve = "1"
45+
CUDA = "5.3"
4246
ChainRulesCore = "1.24"
4347
CommonSolve = "0.2.4"
4448
ConcreteStructs = "0.2.3"
4549
DiffEqBase = "6.155"
4650
DifferentiationInterface = "0.6.1"
51+
Enzyme = "0.13"
52+
ExplicitImports = "1.9"
4753
FastClosures = "0.3.2"
4854
FiniteDiff = "2.24.0"
4955
ForwardDiff = "0.10.36"
5056
InteractiveUtils = "<0.0.1, 1"
51-
LinearAlgebra = "1.10"
5257
LineSearch = "0.1.3"
58+
LinearAlgebra = "1.10"
5359
MaybeInplace = "0.1.4"
60+
NonlinearProblemLibrary = "0.1.2"
5461
NonlinearSolveBase = "1"
62+
Pkg = "1.10"
63+
PolyesterForwardDiff = "0.1"
5564
PrecompileTools = "1.2"
65+
Random = "1.10"
5666
Reexport = "1.2"
5767
ReverseDiff = "1.15"
5868
SciMLBase = "2.50"
69+
SciMLSensitivity = "7.68"
70+
StaticArrays = "1.9"
5971
StaticArraysCore = "1.4.3"
6072
Test = "1.10"
6173
TestItemRunner = "1"
6274
Tracker = "0.2.35"
75+
Zygote = "0.6.70"
6376
julia = "1.10"
6477

6578
[extras]
79+
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
80+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
81+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
82+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
83+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
6684
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
85+
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
86+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
87+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
88+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
90+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
91+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
6792
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6893
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
94+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
95+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6996

7097
[targets]
71-
test = ["InteractiveUtils", "Test", "TestItemRunner"]
98+
test = ["AllocCheck", "Aqua", "CUDA", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]

lib/SimpleNonlinearSolve/src/lbroyden.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,12 @@ end
204204
for i in 1:threshold
205205
static_idx, static_idx_p1 = Val(i - 1), Val(i)
206206
push!(calls, quote
207-
α = ls_cache === nothing ? true : ls_cache(xo, δx)
207+
if ls_cache === nothing
208+
α = true
209+
else
210+
ls_sol = solve!(ls_cache, xo, δx)
211+
α = ls_sol.step_size # Ignores the return code for now
212+
end
208213
x = xo .+ α .* δx
209214
fx = prob.f(x, prob.p)
210215
δf = fx - fo

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module Utils
22

33
using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
44
using ArrayInterface: ArrayInterface
5+
using ConcreteStructs: @concrete
56
using DifferentiationInterface: DifferentiationInterface, Constant
67
using FastClosures: @closure
78
using LinearAlgebra: LinearAlgebra, I, diagind
@@ -116,25 +117,35 @@ restructure(::Number, x::Number) = x
116117
safe_vec(x::AbstractArray) = vec(x)
117118
safe_vec(x::Number) = x
118119

120+
# Tag types returned by `prepare_jacobian` and inspected by `compute_jacobian!!` to
# decide where the Jacobian (or scalar derivative) comes from.
abstract type AbstractJacobianMode end

# The problem function carries its own derivative information (`jac`, and for scalar
# problems also `vjp` / `jvp`), so no automatic differentiation is needed.
struct AnalyticJacobian <: AbstractJacobianMode
end

# Differentiate through DifferentiationInterface without a preparation object.
struct DINoPreparation <: AbstractJacobianMode
end

# Differentiate through DifferentiationInterface, reusing the preparation object stored
# in `prep`. `@concrete` gives the untyped field a concrete type parameter.
@concrete struct DIExtras <: AbstractJacobianMode
    prep
end
127+
128+
# Scalar `x`: pick the Jacobian mode for `compute_jacobian!!`.
#
# A DifferentiationInterface preparation could be built here, but it is deliberately
# skipped: this path must stay completely non-allocating so it can run inside GPU
# kernels. The result is therefore either `AnalyticJacobian()` (the function provides
# `jac`, `vjp` or `jvp`) or `DINoPreparation()`.
function prepare_jacobian(prob, autodiff, _, x::Number)
    f = prob.f
    analytic = SciMLBase.has_jac(f) || SciMLBase.has_vjp(f) || SciMLBase.has_jvp(f)
    return analytic ? AnalyticJacobian() : DINoPreparation()
end
125137
# Array `x`: pick the Jacobian mode for `compute_jacobian!!`.
#
# Returns
#   - `AnalyticJacobian()` when `prob.f` provides `jac`;
#   - `DINoPreparation()` for out-of-place problems whose `x` is an `SArray`
#     (NOTE(review): presumably to keep static-array solves allocation-free for GPU
#     kernels -- confirm);
#   - `DIExtras(prep)` otherwise, where `prep` is the DifferentiationInterface
#     preparation for the in-place (`fx` passed) or out-of-place signature.
#
# The out-of-place preparation must be wrapped in `DIExtras` as well: the
# `compute_jacobian!!` methods only use a preparation when `extras isa DIExtras`, so
# returning the bare object would build it and then silently ignore it.
function prepare_jacobian(prob, autodiff, fx, x)
    SciMLBase.has_jac(prob.f) && return AnalyticJacobian()
    if SciMLBase.isinplace(prob.f)
        return DIExtras(DI.prepare_jacobian(prob.f, fx, autodiff, x, Constant(prob.p)))
    end
    x isa SArray && return DINoPreparation()
    return DIExtras(DI.prepare_jacobian(prob.f, autodiff, x, Constant(prob.p)))
end
135146

136147
function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
137-
if extras === nothing
148+
if extras isa AnalyticJacobian
138149
if SciMLBase.has_jac(prob.f)
139150
return prob.f.jac(x, prob.p)
140151
elseif SciMLBase.has_vjp(prob.f)
@@ -143,11 +154,15 @@ function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
143154
return prob.f.jvp(one(x), x, prob.p)
144155
end
145156
end
146-
return DI.derivative(prob.f, extras, autodiff, x, Constant(prob.p))
157+
if extras isa DIExtras
158+
return DI.derivative(prob.f, extras.prep, autodiff, x, Constant(prob.p))
159+
else
160+
return DI.derivative(prob.f, autodiff, x, Constant(prob.p))
161+
end
147162
end
148163
function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
149164
if J === nothing
150-
if extras === nothing
165+
if extras isa AnalyticJacobian
151166
if SciMLBase.isinplace(prob.f)
152167
J = similar(fx, length(fx), length(x))
153168
prob.f.jac(J, x, prob.p)
@@ -157,12 +172,17 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
157172
end
158173
end
159174
if SciMLBase.isinplace(prob)
160-
return DI.jacobian(prob.f, fx, extras, autodiff, x, Constant(prob.p))
175+
@assert extras isa DIExtras
176+
return DI.jacobian(prob.f, fx, extras.prep, autodiff, x, Constant(prob.p))
161177
else
162-
return DI.jacobian(prob.f, extras, autodiff, x, Constant(prob.p))
178+
if extras isa DIExtras
179+
return DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
180+
else
181+
return DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
182+
end
163183
end
164184
end
165-
if extras === nothing
185+
if extras isa AnalyticJacobian
166186
if SciMLBase.isinplace(prob)
167187
prob.jac(J, x, prob.p)
168188
return J
@@ -171,9 +191,22 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
171191
end
172192
end
173193
if SciMLBase.isinplace(prob)
174-
DI.jacobian!(prob.f, fx, J, extras, autodiff, x, Constant(prob.p))
194+
@assert extras isa DIExtras
195+
DI.jacobian!(prob.f, fx, J, extras.prep, autodiff, x, Constant(prob.p))
175196
else
176-
DI.jacobian!(prob.f, J, extras, autodiff, x, Constant(prob.p))
197+
if ArrayInterface.can_setindex(J)
198+
if extras isa DIExtras
199+
DI.jacobian!(prob.f, J, extras.prep, autodiff, x, Constant(prob.p))
200+
else
201+
DI.jacobian!(prob.f, J, autodiff, x, Constant(prob.p))
202+
end
203+
else
204+
if extras isa DIExtras
205+
J = DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
206+
else
207+
J = DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
208+
end
209+
end
177210
end
178211
return J
179212
end

0 commit comments

Comments
 (0)