Skip to content

Commit e41ee04

Browse files
refactor: centralize problem kwargs handling
1 parent 753cae3 commit e41ee04

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

src/problems/odeproblem.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,13 @@ end
6363
check_compatibility && check_compatible_system(ODEProblem, sys)
6464

6565
f, u0, p = process_SciMLProblem(ODEFunction{iip, spec}, sys, u0map, parammap;
66-
t = tspan !== nothing ? tspan[1] : tspan,
67-
check_length, eval_expression, eval_module, kwargs...)
68-
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
69-
70-
kwargs = filter_kwargs(kwargs)
71-
72-
kwargs1 = (;)
73-
if cbs !== nothing
74-
kwargs1 = merge(kwargs1, (callback = cbs,))
75-
end
76-
77-
tstops = SymbolicTstops(sys; eval_expression, eval_module)
78-
if tstops !== nothing
79-
kwargs1 = merge(kwargs1, (; tstops))
80-
end
66+
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
67+
eval_module, check_compatibility, kwargs...)
8168

69+
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...)
8270
# Call `remake` so it runs initialization if it is trivial
8371
return remake(ODEProblem{iip}(
84-
f, u0, tspan, p, StandardODEProblem(); kwargs1..., kwargs...))
72+
f, u0, tspan, p, StandardODEProblem(); kwargs...))
8573
end
8674

8775
function check_compatible_system(T::Union{Type{ODEFunction}, Type{ODEProblem}}, sys::System)

src/systems/problem_utils.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,26 @@ function SciMLBase.detect_cycles(sys::AbstractSystem, varmap::Dict{Any, Any}, va
986986
return !isempty(cycles)
987987
end
988988

989+
function process_kwargs(sys::System; callback = nothing, eval_expression = false,
990+
eval_module = @__MODULE__, kwargs...)
991+
kwargs = filter_kwargs(kwargs)
992+
kwargs1 = (;)
993+
994+
if is_time_dependent(sys)
995+
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
996+
if cbs !== nothing
997+
kwargs1 = merge(kwargs1, (callback = cbs,))
998+
end
999+
1000+
tstops = SymbolicTstops(sys; eval_expression, eval_module)
1001+
if tstops !== nothing
1002+
kwargs1 = merge(kwargs1, (; tstops))
1003+
end
1004+
end
1005+
1006+
return merge(kwargs1, kwargs)
1007+
end
1008+
9891009
"""
9901010
$(TYPEDSIGNATURES)
9911011

0 commit comments

Comments
 (0)