Skip to content

Commit d96cb0a

Browse files
feat: make Pre automatically recurse into subexpressions
Close #4095
1 parent b418b08 commit d96cb0a

File tree

2 files changed

+41
-27
lines changed

2 files changed

+41
-27
lines changed

lib/ModelingToolkitBase/src/systems/callbacks.jl

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -217,35 +217,42 @@ unPre(x::Symbolics.Arr) = unPre(unwrap(x))
217217
unPre(x::SymbolicT) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x
218218
distribute_shift_into_operator(::Pre) = false
219219

220-
function (p::Pre)(x)
221-
iw = Symbolics.iswrapped(x)
222-
x = unwrap(x)
223-
# non-symbolic values don't change
224-
SU.isconst(x) && return x
225-
if symbolic_type(x) == NotSymbolic()
226-
return x
227-
end
228-
# differential variables are default-toterm-ed
229-
if iscall(x) && operation(x) isa Differential
230-
x = default_toterm(x)
231-
end
232-
# don't double wrap
233-
iscall(x) && operation(x) isa Pre && return x
234-
result = if iscall(x) && operation(x) === getindex
235-
# instead of `Pre(x[1])` create `Pre(x)[1]`
236-
# which allows parameter indexing to handle this case automatically.
237-
arr = arguments(x)[1]
238-
p(arr)[arguments(x)[2:end]...]
239-
else
240-
term(p, x; type = symtype(x), shape = SU.shape(x))
241-
end
242-
# the result should be a parameter
243-
result = toparam(result)
244-
if iw
245-
result = wrap(result)
220+
(p::Pre)(x::Num) = Num(p(unwrap(x)))
221+
(p::Pre)(x::Symbolics.Arr{T, N}) where {T, N} = Symbolics.Arr{T, N}(p(unwrap(x)))
222+
(p::Pre)(x::Symbolics.SymStruct{T}) where {T} = Symbolics.SymStruct{T}(p(unwrap(x)))
223+
(p::Pre)(x::Symbolics.CallAndWrap{T}) where {T} = Symbolics.CallAndWrap{T}(p(unwrap(x)))
224+
function (p::Pre)(x::SymbolicT)
225+
iscall(x) || return x
226+
return Moshi.Match.@match x begin
227+
BSImpl.Term(; f) && if f isa Pre end => return x
228+
BSImpl.Term(; f) && if f isa Differential end => begin
229+
return p(default_toterm(x))
230+
end
231+
BSImpl.Term(; f, args, type, shape) && if f === getindex end => begin
232+
arrpre = p(args[1])
233+
Moshi.Match.@match arrpre begin
234+
BSImpl.Term(; f = f2) && if f2 isa Pre end => begin
235+
newargs = SArgsT((x,))
236+
return toparam(BSImpl.Term{VartypeT}(p, newargs; type, shape))
237+
end
238+
_ => begin
239+
newargs = copy(parent(args))
240+
newargs[1] = arrpre
241+
return toparam(BSImpl.Term{VartypeT}(f, newargs; type, shape))
242+
end
243+
end
244+
end
245+
BSImpl.Term(; f, type, shape) && if f isa SymbolicT && !SU.is_function_symbolic(f) end => begin
246+
return toparam(BSImpl.Term{VartypeT}(p, SArgsT((x,)); type, shape))
247+
end
248+
_ => begin
249+
op = operation(x)
250+
args = map(p, arguments(x))
251+
return toparam(maketerm(SymbolicT, op, args, nothing; type = symtype(x)))
252+
end
246253
end
247-
return result
248254
end
255+
(::Pre)(x) = x
249256
haspre(eq::Equation) = haspre(eq.lhs) || haspre(eq.rhs)
250257
haspre(O) = recursive_hasoperator(Pre, O)
251258

lib/ModelingToolkitBase/test/symbolic_events.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,4 +1824,11 @@ if !@isdefined(ModelingToolkit)
18241824
@mtkcompile sys = System(eqs, t, [X], [p, Kᵢ, Kₐ, K]; discrete_events)
18251825
@test_nowarn ODEProblem(sys, [X => 1, p => 1, Kᵢ => 1, Kₐ => 2], (0.0, 1.0))
18261826
end
1827+
1828+
@testset "Issue:4095: `Pre` recurses into expressions" begin
1829+
@variables x(t)
1830+
@parameters p (f::Function)(..)
1831+
@discretes d(t)
1832+
@test isequal(Pre(2x^2 + 3sin(f(x)) - ifelse(p < 0, d, d + 2) + 2p), 2Pre(x)^2 + 3sin(f(Pre(x))) - ifelse(p < 0, Pre(d), Pre(d) + 2) + 2p)
1833+
end
18271834
end

0 commit comments

Comments
 (0)