Skip to content

Commit 5772197

Browse files
committed
More ergonomic initialization for shifted variables
1 parent 1c0c328 commit 5772197

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

src/discretedomain.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ struct Shift <: Operator
2727
steps::Int
2828
Shift(t, steps = 1) = new(value(t), steps)
2929
end
30+
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps
3031
function (D::Shift)(x, allow_zero = false)
3132
!allow_zero && D.steps == 0 && return x
3233
Term{symtype(x)}(D, Any[x])

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,3 +858,5 @@ function fast_substitute(expr, pair::Pair)
858858
symtype(expr);
859859
metadata = metadata(expr))
860860
end
861+
862+
normalize_to_differential(s) = s

src/variables.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ isoutput(x) = isvarkind(VariableOutput, x)
3636
isirreducible(x) = isvarkind(VariableIrreducible, x)
3737
state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64
3838

39+
function default_toterm(x)
40+
if istree(x) && (op = operation(x)) isa Operator
41+
if !(op isa Differential)
42+
x = normalize_to_differential(op)(arguments(x)...)
43+
end
44+
Symbolics.diff2term(x)
45+
else
46+
x
47+
end
48+
end
49+
3950
"""
4051
$(SIGNATURES)
4152
@@ -44,7 +55,7 @@ and creates the array of values in the correct order with default values when
4455
applicable.
4556
"""
4657
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
47-
toterm = Symbolics.diff2term, promotetoconcrete = nothing,
58+
toterm = default_toterm, promotetoconcrete = nothing,
4859
tofloat = true, use_union = false)
4960
varlist = collect(map(unwrap, varlist))
5061

test/clock.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ eqs = [yd ~ Sample(t, dt)(y)
113113
]
114114
@named sys = ODESystem(eqs)
115115
ss = structural_simplify(sys)
116-
prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, 1.0), [kp => 1.0; z => 0.0; D(z) => 0.0])
116+
prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, 1.0),
117+
[kp => 1.0; z => 0.0; z(k + 1) => 0.0])
117118
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
118119
# For all inputs in parameters, just initialize them to 0.0, and then set them
119120
# in the callback.

0 commit comments

Comments
 (0)