Skip to content

Commit 37c563a

Browse files
prbzrgChrisRackauckas
authored andcommitted
simplify :alg extraction
1 parent 0a77e9f commit 37c563a

File tree

1 file changed

+43
-68
lines changed

1 file changed

+43
-68
lines changed

src/solve.jl

Lines changed: 43 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -501,36 +501,19 @@ function init(prob::AbstractJumpProblem, args...; kwargs...)
501501
end
502502

503503
function init_up(prob::DEProblem, sensealg, u0, p, args...; kwargs...)
504-
if haskey(kwargs, :alg) && (isempty(args) || args[1] === nothing)
505-
alg = kwargs[:alg]
504+
alg = extract_alg(args, kwargs, prob.kwargs)
505+
if isnothing(alg) # Default algorithm handling
506+
_prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); kwargs...)
507+
init_call(_prob, args...; kwargs...)
508+
else
506509
_prob = get_concrete_problem(prob, isadaptive(alg); kwargs...)
507510
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
508511
check_prob_alg_pairing(_prob, alg) # alg for improved inference
509-
510-
if length(args) <= 1
511-
init_call(_prob, _alg; kwargs...)
512-
else
512+
if length(args) > 1
513513
init_call(_prob, _alg, Base.tail(args)...; kwargs...)
514-
end
515-
elseif haskey(prob.kwargs, :alg) && (isempty(args) || args[1] === nothing)
516-
alg = prob.kwargs[:alg]
517-
_prob = get_concrete_problem(prob, isadaptive(alg); kwargs...)
518-
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
519-
check_prob_alg_pairing(_prob, alg) # alg for improved inference
520-
if length(args) <= 1
521-
init_call(_prob, _alg; kwargs...)
522514
else
523-
init_call(_prob, _alg, Base.tail(args)...; kwargs...)
515+
init_call(_prob, _alg; kwargs...)
524516
end
525-
elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm
526-
alg = args[1]
527-
_prob = get_concrete_problem(prob, isadaptive(alg); kwargs...)
528-
check_prob_alg_pairing(_prob, alg)
529-
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
530-
init_call(_prob, _alg, Base.tail(args)...; kwargs...)
531-
else # Default algorithm handling
532-
_prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); kwargs...)
533-
init_call(_prob, args...; kwargs...)
534517
end
535518
end
536519

@@ -1005,36 +988,20 @@ end
1005988

1006989
function solve_up(prob::Union{DEProblem, NonlinearProblem}, sensealg, u0, p, args...;
1007990
kwargs...)
1008-
if haskey(kwargs, :alg) && (isempty(args) || args[1] === nothing)
1009-
alg = kwargs[:alg]
991+
alg = extract_alg(args, kwargs, prob.kwargs)
992+
if isnothing(alg) # Default algorithm handling
993+
_prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0,
994+
p = p, kwargs...)
995+
solve_call(_prob, args...; kwargs...)
996+
else
1010997
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
1011998
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
1012999
check_prob_alg_pairing(_prob, alg) # use alg for improved inference
1013-
if length(args) <= 1
1014-
solve_call(_prob, _alg; kwargs...)
1015-
else
1000+
if length(args) > 1
10161001
solve_call(_prob, _alg, Base.tail(args)...; kwargs...)
1017-
end
1018-
elseif haskey(prob.kwargs, :alg) && (isempty(args) || args[1] === nothing)
1019-
alg = prob.kwargs[:alg]
1020-
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
1021-
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
1022-
check_prob_alg_pairing(_prob, alg) # use alg for improved inference
1023-
if length(args) <= 1
1024-
solve_call(_prob, _alg; kwargs...)
10251002
else
1026-
solve_call(_prob, _alg, Base.tail(args)...; kwargs...)
1003+
solve_call(_prob, _alg; kwargs...)
10271004
end
1028-
elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm
1029-
alg = args[1]
1030-
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
1031-
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
1032-
check_prob_alg_pairing(_prob, alg) # use alg for improved inference
1033-
solve_call(_prob, _alg, Base.tail(args)...; kwargs...)
1034-
else # Default algorithm handling
1035-
_prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0,
1036-
p = p, kwargs...)
1037-
solve_call(_prob, args...; kwargs...)
10381005
end
10391006
end
10401007

@@ -1424,16 +1391,12 @@ end
14241391

14251392
function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
14261393
kwargs...)
1427-
alg, _prob = if haskey(kwargs, :alg) && (isempty(args) || args[1] === nothing)
1428-
alg = kwargs[:alg]
1429-
alg, get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
1430-
elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm
1431-
alg = args[1]
1432-
alg, get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
1433-
else # Default algorithm handling
1434-
alg = isempty(args) ? nothing : args[1]
1435-
alg, get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, p = p,
1436-
kwargs...)
1394+
alg = extract_alg(args, kwargs, prob.kwargs)
1395+
if isnothing(alg) # Default algorithm handling
1396+
_prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0,
1397+
p = p, kwargs...)
1398+
else
1399+
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
14371400
end
14381401

14391402
if has_kwargs(_prob)
@@ -1458,16 +1421,12 @@ end
14581421

14591422
function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
14601423
kwargs...)
1461-
alg, _prob = if haskey(kwargs, :alg) && (isempty(args) || args[1] === nothing)
1462-
alg = kwargs[:alg]
1463-
alg, get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
1464-
elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm
1465-
alg = args[1]
1466-
alg, get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
1467-
else # Default algorithm handling
1468-
alg = isempty(args) ? nothing : args[1]
1469-
alg, get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, p = p,
1470-
kwargs...)
1424+
alg = extract_alg(args, kwargs, prob.kwargs)
1425+
if isnothing(alg) # Default algorithm handling
1426+
_prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0,
1427+
p = p, kwargs...)
1428+
else
1429+
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
14711430
end
14721431

14731432
if has_kwargs(_prob)
@@ -1490,6 +1449,22 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba
14901449
end
14911450
end
14921451

1452+
@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs)
1453+
if isempty(solve_args) || isnothing(solve_args[1])
1454+
if haskey(solve_kwargs, :alg)
1455+
solve_kwargs[:alg]
1456+
elseif haskey(prob_kwargs, :alg)
1457+
prob_kwargs[:alg]
1458+
else
1459+
nothing
1460+
end
1461+
elseif solve_args[1] isa DEAlgorithm
1462+
solve_args[1]
1463+
else
1464+
nothing
1465+
end
1466+
end
1467+
14931468
####
14941469
# Catch undefined AD overload cases
14951470

0 commit comments

Comments
 (0)