Skip to content

Commit 2cb8d3f

Browse files
committed
thread alias through to __solve
1 parent bd4fba9 commit 2cb8d3f

File tree

17 files changed

+69
-41
lines changed

17 files changed

+69
-41
lines changed

ext/NonlinearSolveFixedPointAccelerationExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
88

99
function SciMLBase.__solve(
1010
prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
11-
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
11+
abstol = nothing, maxiters = 1000, alias = SciMLBase.NonlinearAliasSpecifier(alias_u0 = false),
1212
show_trace::Val = Val(false), termination_condition = nothing, kwargs...
1313
)
14+
alias_u0 = alias.alias_u0
1415
NonlinearSolveBase.assert_extension_supported_termination_condition(
1516
termination_condition, alg
1617
)

ext/NonlinearSolveNLSolversExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ const DI = DifferentiationInterface
1313

1414
function SciMLBase.__solve(
1515
prob::NonlinearProblem, alg::NLSolversJL, args...;
16-
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false,
16+
abstol = nothing, reltol = nothing, maxiters = 1000, alias = SciMLBase.NonlinearAliasSpecifier(alias_u0 = false),
1717
termination_condition = nothing, kwargs...
1818
)
19+
alias_u0 = alias.alias_u0
1920
NonlinearSolveBase.assert_extension_supported_termination_condition(
2021
termination_condition, alg
2122
)

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
99

1010
function SciMLBase.__solve(
1111
prob::NonlinearProblem, alg::NLsolveJL, args...;
12-
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
12+
abstol = nothing, maxiters = 1000, alias = SciMLBase.NonlinearAliasSpecifier(alias_u0 = false),
1313
termination_condition = nothing, trace_level = TraceMinimal(),
1414
store_trace::Val = Val(false), show_trace::Val = Val(false), kwargs...
1515
)
16+
alias_u0 = alias.alias_u0
1617
NonlinearSolveBase.assert_extension_supported_termination_condition(
1718
termination_condition, alg
1819
)

ext/NonlinearSolvePETScExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ using SparseArrays: AbstractSparseMatrix
1414
function SciMLBase.__solve(
1515
prob::NonlinearProblem, alg::PETScSNES, args...;
1616
abstol = nothing, reltol = nothing,
17-
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
17+
maxiters = 1000, alias = SciMLBase.NonlinearAliasSpecifier(alias_u0 = false), termination_condition = nothing,
1818
show_trace::Val = Val(false), kwargs...
1919
)
20+
alias_u0 = alias.alias_u0
2021
# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
2122
NonlinearSolveBase.assert_extension_supported_termination_condition(
2223
termination_condition, alg; abs_norm_supported = false

ext/NonlinearSolveSIAMFANLEquationsExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ end
3939

4040
function SciMLBase.__solve(
4141
prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...;
42-
abstol = nothing, reltol = nothing, alias_u0::Bool = false, maxiters = 1000,
42+
abstol = nothing, reltol = nothing, alias = SciMLBase.NonlinearAliasSpecifier(alias_u0 = false), maxiters = 1000,
4343
termination_condition = nothing, show_trace = Val(false), kwargs...
4444
)
45+
alias_u0 = alias.alias_u0
4546
NonlinearSolveBase.assert_extension_supported_termination_condition(
4647
termination_condition, alg
4748
)

ext/NonlinearSolveSpeedMappingExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
88

99
function SciMLBase.__solve(
1010
prob::NonlinearProblem, alg::SpeedMappingJL, args...;
11-
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
11+
abstol = nothing, maxiters = 1000, alias = SciMLBase.NonlinearAliasSpecifier(alias_u0 = false),
1212
maxtime = nothing, store_trace::Val = Val(false),
1313
termination_condition = nothing, kwargs...
1414
)
15+
alias_u0 = alias.alias_u0
1516
NonlinearSolveBase.assert_extension_supported_termination_condition(
1617
termination_condition, alg
1718
)

lib/NonlinearSolveBase/src/polyalg.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,10 @@ end
117117
function SciMLBase.__init(
118118
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...;
119119
stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000,
120-
internalnorm::IN = L2_NORM, alias_u0 = false, verbose = true,
120+
internalnorm::IN = L2_NORM, alias = NonlinearAliasSpecifier(alias_u0 = false), verbose = true,
121121
initializealg = NonlinearSolveDefaultInit(), kwargs...
122122
) where {IN}
123+
alias_u0 = alias.alias_u0
123124
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
124125
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
125126
immutable (checked using `ArrayInterface.ismutable`)."
@@ -135,7 +136,7 @@ function SciMLBase.__init(
135136
map(alg.algs) do solver
136137
SciMLBase.__init(
137138
prob, solver, args...;
138-
stats, maxtime, internalnorm, alias_u0, verbose,
139+
stats, maxtime, internalnorm, alias, verbose,
139140
initializealg = SciMLBase.NoInit(), kwargs...
140141
)
141142
end,

lib/NonlinearSolveBase/src/solve.jl

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,24 @@ function solve(prob::AbstractNonlinearProblem, args...; sensealg = nothing,
5050
sensealg = prob.kwargs[:sensealg]
5151
end
5252

53-
if haskey(prob.kwargs, :alias_u0)
54-
@warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
55-
alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0])
56-
elseif haskey(kwargs, :alias_u0)
57-
@warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
58-
alias_spec = NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0])
59-
end
60-
61-
if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa Bool
62-
alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias])
53+
alias_spec = if haskey(kwargs, :alias) && kwargs[:alias] isa NonlinearAliasSpecifier
54+
kwargs[:alias]
55+
elseif haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier
56+
prob.kwargs[:alias]
6357
elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool
64-
alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias])
65-
end
66-
67-
if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier
68-
alias_spec = prob.kwargs[:alias]
69-
elseif haskey(kwargs, :alias) && kwargs[:alias] isa NonlinearAliasSpecifier
70-
alias_spec = kwargs[:alias]
58+
NonlinearAliasSpecifier(alias = kwargs[:alias])
59+
elseif haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa Bool
60+
NonlinearAliasSpecifier(alias = prob.kwargs[:alias])
61+
elseif haskey(kwargs, :alias_u0)
62+
@warn lazy"The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
63+
NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0])
64+
elseif haskey(prob.kwargs, :alias_u0)
65+
@warn lazy"The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
66+
NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0])
7167
else
72-
alias_spec = NonlinearAliasSpecifier(alias_u0 = false)
68+
NonlinearAliasSpecifier(alias_u0 = false)
7369
end
7470

75-
alias_u0 = alias_spec.alias_u0
76-
7771
u0 = u0 !== nothing ? u0 : prob.u0
7872
p = p !== nothing ? p : prob.p
7973

@@ -83,7 +77,7 @@ function solve(prob::AbstractNonlinearProblem, args...; sensealg = nothing,
8377
u0,
8478
p,
8579
args...;
86-
alias_u0 = alias_u0,
80+
alias = alias_spec,
8781
originator = SciMLBase.ChainRulesOriginator(),
8882
kwargs...))
8983
else
@@ -92,7 +86,7 @@ function solve(prob::AbstractNonlinearProblem, args...; sensealg = nothing,
9286
u0,
9387
p,
9488
args...;
95-
alias_u0 = alias_u0,
89+
alias = alias_spec,
9690
originator = SciMLBase.ChainRulesOriginator(),
9791
kwargs...)
9892
end
@@ -170,10 +164,28 @@ function init(
170164
sensealg = prob.kwargs[:sensealg]
171165
end
172166

167+
alias_spec = if haskey(kwargs, :alias) && kwargs[:alias] isa NonlinearAliasSpecifier
168+
kwargs[:alias]
169+
elseif haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier
170+
prob.kwargs[:alias]
171+
elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool
172+
NonlinearAliasSpecifier(alias = kwargs[:alias])
173+
elseif haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa Bool
174+
NonlinearAliasSpecifier(alias = prob.kwargs[:alias])
175+
elseif haskey(kwargs, :alias_u0)
176+
@warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
177+
NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0])
178+
elseif haskey(prob.kwargs, :alias_u0)
179+
@warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
180+
NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0])
181+
else
182+
NonlinearAliasSpecifier(alias_u0 = false)
183+
end
184+
173185
u0 = u0 !== nothing ? u0 : prob.u0
174186
p = p !== nothing ? p : prob.p
175187

176-
init_up(prob, sensealg, u0, p, args...; kwargs...)
188+
init_up(prob, sensealg, u0, p, args...; alias = alias_spec, kwargs...)
177189
end
178190

179191
function init_up(prob::AbstractNonlinearProblem,
@@ -375,13 +387,14 @@ end
375387

376388
@generated function __generated_polysolve(
377389
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...;
378-
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true,
390+
stats = NLStats(0, 0, 0, 0, 0), alias = NonlinearAliasSpecifier(alias_u0 = false), verbose = true,
379391
initializealg = NonlinearSolveDefaultInit(), kwargs...
380392
) where {N}
381393
sol_syms = [gensym("sol") for _ in 1:N]
382394
prob_syms = [gensym("prob") for _ in 1:N]
383395
u_result_syms = [gensym("u_result") for _ in 1:N]
384396
calls = [quote
397+
alias_u0 = alias.alias_u0
385398
current = alg.start_index
386399
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
387400
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \

lib/NonlinearSolveFirstOrder/src/solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,12 @@ NonlinearSolveBase.@internal_caches(GeneralizedFirstOrderAlgorithmCache,
127127

128128
function SciMLBase.__init(
129129
prob::AbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...;
130-
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000,
130+
stats = NLStats(0, 0, 0, 0, 0), alias = NonlinearSolveBase.NonlinearAliasSpecifier(alias_u0 = false), maxiters = 1000,
131131
abstol = nothing, reltol = nothing, maxtime = nothing,
132132
termination_condition = nothing, internalnorm::IN = L2_NORM,
133133
linsolve_kwargs = (;), initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs...
134134
) where {IN}
135+
alias_u0 = alias.alias_u0
135136
@set! alg.autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
136137
provided_jvp_autodiff = alg.jvp_autodiff !== nothing
137138
@set! alg.jvp_autodiff = if !provided_jvp_autodiff && alg.autodiff !== nothing &&

lib/NonlinearSolveQuasiNewton/src/solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,13 @@ NonlinearSolveBase.@internal_caches(QuasiNewtonCache,
145145

146146
function SciMLBase.__init(
147147
prob::AbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...;
148-
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxtime = nothing,
148+
stats = NLStats(0, 0, 0, 0, 0), alias = NonlinearSolveBase.NonlinearAliasSpecifier(alias_u0 = false), maxtime = nothing,
149149
maxiters = 1000, abstol = nothing, reltol = nothing,
150150
linsolve_kwargs = (;), termination_condition = nothing,
151151
internalnorm::F = L2_NORM, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(),
152152
kwargs...
153153
) where {F}
154+
alias_u0 = alias.alias_u0
154155
timer = get_timer_output()
155156
@static_timeit timer "cache construction" begin
156157
u = Utils.maybe_unaliased(prob.u0, alias_u0)

0 commit comments

Comments
 (0)