Skip to content

Commit 7b78bab

Browse files
refactor: fix type-instability in shift-related utilities
1 parent 8ad6e85 commit 7b78bab

File tree

1 file changed

+12
-17
lines changed

1 file changed

+12
-17
lines changed

lib/ModelingToolkitBase/src/variables.jl

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -359,29 +359,24 @@ Simplify multiple shifts: Shift(t, k1)(Shift(t, k2)(x)) becomes Shift(t, k1+k2)(
359359
"""
360360
function simplify_shifts(var::SymbolicT)
361361
ModelingToolkitBase.hasshift(var) || return var
362-
return SU.Rewriters.Postwalk(_simplify_shifts)(var)
362+
return SU.Rewriters.Postwalk(_simplify_shifts)(var)::SymbolicT
363363
end
364364

365+
distribute_shift(eq::Equation) = distribute_shift(eq.lhs) ~ distribute_shift(eq.rhs)
366+
distribute_shift(var::Union{Num, Arr}) = distribute_shift(unwrap(var))
365367
"""
366368
Distribute a shift applied to a whole expression or equation.
367369
Shift(t, 1)(x + y) will become Shift(t, 1)(x) + Shift(t, 1)(y).
368370
Only shifts variables whose independent variable is the same t that appears in the Shift (i.e. constants, time-independent parameters, etc. do not get shifted).
369371
"""
370-
function distribute_shift(var)
371-
var = unwrap(var)
372-
var isa Equation && return distribute_shift(var.lhs) ~ distribute_shift(var.rhs)
373-
374-
ModelingToolkitBase.hasshift(var) || return var
375-
shift = operation(var)
376-
shift isa Shift || return var
377-
378-
shift = operation(var)
379-
expr = only(arguments(var))
380-
if expr isa Equation
381-
return distribute_shift(shift(expr.lhs)) ~ distribute_shift(shift(expr.rhs))
372+
function distribute_shift(var::SymbolicT)
373+
Moshi.Match.@match var begin
374+
BSImpl.Term(; f, args) && if f isa Shift end => begin
375+
shiftexpr = _distribute_shift(args[1], f)
376+
return simplify_shifts(shiftexpr)
377+
end
378+
_ => return var
382379
end
383-
shiftexpr = _distribute_shift(expr, shift)
384-
return simplify_shifts(shiftexpr)
385380
end
386381

387382
"""
@@ -391,10 +386,10 @@ Whether `distribute_shift` should distribute shifts into the given operation.
391386
"""
392387
distribute_shift_into_operator(_) = true
393388

394-
function _distribute_shift(expr, shift)
389+
function _distribute_shift(expr::SymbolicT, shift)
395390
if iscall(expr)
396391
op = operation(expr)
397-
distribute_shift_into_operator(op) || return expr
392+
distribute_shift_into_operator(op)::Bool || return expr
398393
args = arguments(expr)
399394

400395
if ModelingToolkitBase.isvariable(expr) && operation(expr) !== getindex &&

0 commit comments

Comments
 (0)