@@ -501,36 +501,19 @@ function init(prob::AbstractJumpProblem, args...; kwargs...)
501
501
end
502
502
503
503
function init_up (prob:: DEProblem , sensealg, u0, p, args... ; kwargs... )
504
- if haskey (kwargs, :alg ) && (isempty (args) || args[1 ] === nothing )
505
- alg = kwargs[:alg ]
504
+ alg = extract_alg (args, kwargs, prob. kwargs)
505
+ if isnothing (alg) # Default algorithm handling
506
+ _prob = get_concrete_problem (prob, ! (typeof (prob) <: DiscreteProblem ); kwargs... )
507
+ init_call (_prob, args... ; kwargs... )
508
+ else
506
509
_prob = get_concrete_problem (prob, isadaptive (alg); kwargs... )
507
510
_alg = prepare_alg (alg, _prob. u0, _prob. p, _prob)
508
511
check_prob_alg_pairing (_prob, alg) # alg for improved inference
509
-
510
- if length (args) <= 1
511
- init_call (_prob, _alg; kwargs... )
512
- else
512
+ if length (args) > 1
513
513
init_call (_prob, _alg, Base. tail (args)... ; kwargs... )
514
- end
515
- elseif haskey (prob. kwargs, :alg ) && (isempty (args) || args[1 ] === nothing )
516
- alg = prob. kwargs[:alg ]
517
- _prob = get_concrete_problem (prob, isadaptive (alg); kwargs... )
518
- _alg = prepare_alg (alg, _prob. u0, _prob. p, _prob)
519
- check_prob_alg_pairing (_prob, alg) # alg for improved inference
520
- if length (args) <= 1
521
- init_call (_prob, _alg; kwargs... )
522
514
else
523
- init_call (_prob, _alg, Base . tail (args) ... ; kwargs... )
515
+ init_call (_prob, _alg; kwargs... )
524
516
end
525
- elseif ! isempty (args) && typeof (args[1 ]) <: DEAlgorithm
526
- alg = args[1 ]
527
- _prob = get_concrete_problem (prob, isadaptive (alg); kwargs... )
528
- check_prob_alg_pairing (_prob, alg)
529
- _alg = prepare_alg (alg, _prob. u0, _prob. p, _prob)
530
- init_call (_prob, _alg, Base. tail (args)... ; kwargs... )
531
- else # Default algorithm handling
532
- _prob = get_concrete_problem (prob, ! (typeof (prob) <: DiscreteProblem ); kwargs... )
533
- init_call (_prob, args... ; kwargs... )
534
517
end
535
518
end
536
519
@@ -1005,36 +988,20 @@ end
1005
988
1006
989
function solve_up (prob:: Union{DEProblem, NonlinearProblem} , sensealg, u0, p, args... ;
1007
990
kwargs... )
1008
- if haskey (kwargs, :alg ) && (isempty (args) || args[1 ] === nothing )
1009
- alg = kwargs[:alg ]
991
+ alg = extract_alg (args, kwargs, prob. kwargs)
992
+ if isnothing (alg) # Default algorithm handling
993
+ _prob = get_concrete_problem (prob, ! (typeof (prob) <: DiscreteProblem ); u0 = u0,
994
+ p = p, kwargs... )
995
+ solve_call (_prob, args... ; kwargs... )
996
+ else
1010
997
_prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1011
998
_alg = prepare_alg (alg, _prob. u0, _prob. p, _prob)
1012
999
check_prob_alg_pairing (_prob, alg) # use alg for improved inference
1013
- if length (args) <= 1
1014
- solve_call (_prob, _alg; kwargs... )
1015
- else
1000
+ if length (args) > 1
1016
1001
solve_call (_prob, _alg, Base. tail (args)... ; kwargs... )
1017
- end
1018
- elseif haskey (prob. kwargs, :alg ) && (isempty (args) || args[1 ] === nothing )
1019
- alg = prob. kwargs[:alg ]
1020
- _prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1021
- _alg = prepare_alg (alg, _prob. u0, _prob. p, _prob)
1022
- check_prob_alg_pairing (_prob, alg) # use alg for improved inference
1023
- if length (args) <= 1
1024
- solve_call (_prob, _alg; kwargs... )
1025
1002
else
1026
- solve_call (_prob, _alg, Base . tail (args) ... ; kwargs... )
1003
+ solve_call (_prob, _alg; kwargs... )
1027
1004
end
1028
- elseif ! isempty (args) && typeof (args[1 ]) <: DEAlgorithm
1029
- alg = args[1 ]
1030
- _prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1031
- _alg = prepare_alg (alg, _prob. u0, _prob. p, _prob)
1032
- check_prob_alg_pairing (_prob, alg) # use alg for improved inference
1033
- solve_call (_prob, _alg, Base. tail (args)... ; kwargs... )
1034
- else # Default algorithm handling
1035
- _prob = get_concrete_problem (prob, ! (typeof (prob) <: DiscreteProblem ); u0 = u0,
1036
- p = p, kwargs... )
1037
- solve_call (_prob, args... ; kwargs... )
1038
1005
end
1039
1006
end
1040
1007
@@ -1424,16 +1391,12 @@ end
1424
1391
1425
1392
function _solve_adjoint (prob, sensealg, u0, p, originator, args... ; merge_callbacks = true ,
1426
1393
kwargs... )
1427
- alg, _prob = if haskey (kwargs, :alg ) && (isempty (args) || args[1 ] === nothing )
1428
- alg = kwargs[:alg ]
1429
- alg, get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1430
- elseif ! isempty (args) && typeof (args[1 ]) <: DEAlgorithm
1431
- alg = args[1 ]
1432
- alg, get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1433
- else # Default algorithm handling
1434
- alg = isempty (args) ? nothing : args[1 ]
1435
- alg, get_concrete_problem (prob, ! (typeof (prob) <: DiscreteProblem ); u0 = u0, p = p,
1436
- kwargs... )
1394
+ alg = extract_alg (args, kwargs, prob. kwargs)
1395
+ if isnothing (alg) # Default algorithm handling
1396
+ _prob = get_concrete_problem (prob, ! (typeof (prob) <: DiscreteProblem ); u0 = u0,
1397
+ p = p, kwargs... )
1398
+ else
1399
+ _prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1437
1400
end
1438
1401
1439
1402
if has_kwargs (_prob)
@@ -1458,16 +1421,12 @@ end
1458
1421
1459
1422
function _solve_forward (prob, sensealg, u0, p, originator, args... ; merge_callbacks = true ,
1460
1423
kwargs... )
1461
- alg, _prob = if haskey (kwargs, :alg ) && (isempty (args) || args[1 ] === nothing )
1462
- alg = kwargs[:alg ]
1463
- alg, get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1464
- elseif ! isempty (args) && typeof (args[1 ]) <: DEAlgorithm
1465
- alg = args[1 ]
1466
- alg, get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1467
- else # Default algorithm handling
1468
- alg = isempty (args) ? nothing : args[1 ]
1469
- alg, get_concrete_problem (prob, ! (typeof (prob) <: DiscreteProblem ); u0 = u0, p = p,
1470
- kwargs... )
1424
+ alg = extract_alg (args, kwargs, prob. kwargs)
1425
+ if isnothing (alg) # Default algorithm handling
1426
+ _prob = get_concrete_problem (prob, ! (typeof (prob) <: DiscreteProblem ); u0 = u0,
1427
+ p = p, kwargs... )
1428
+ else
1429
+ _prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
1471
1430
end
1472
1431
1473
1432
if has_kwargs (_prob)
@@ -1490,6 +1449,22 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba
1490
1449
end
1491
1450
end
1492
1451
1452
+ @inline function extract_alg (solve_args, solve_kwargs, prob_kwargs)
1453
+ if isempty (solve_args) || isnothing (solve_args[1 ])
1454
+ if haskey (solve_kwargs, :alg )
1455
+ solve_kwargs[:alg ]
1456
+ elseif haskey (prob_kwargs, :alg )
1457
+ prob_kwargs[:alg ]
1458
+ else
1459
+ nothing
1460
+ end
1461
+ elseif solve_args[1 ] isa DEAlgorithm
1462
+ solve_args[1 ]
1463
+ else
1464
+ nothing
1465
+ end
1466
+ end
1467
+
1493
1468
# ###
1494
1469
# Catch undefined AD overload cases
1495
1470
0 commit comments