Skip to content

Commit 2193ef2

Browse files
Merge pull request SciML#3611 from AayushSabharwal/as/iflifting-fixes
fix: fix nested conditions in `IfLifting`
2 parents cdb2cfb + a9ac726 commit 2193ef2

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

src/systems/if_lifting.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ function (cw::CondRewriter)(expr, dep)
111111
# and ELSE branch is true
112112
# similarly for expression being false
113113
return (ifelse(rw_cond, rw_conda, rw_condb),
114-
implies(ctrue, truea) | implies(cfalse, trueb),
115-
implies(ctrue, falsea) | implies(cfalse, falseb))
114+
ctrue & truea | cfalse & trueb,
115+
ctrue & falsea | cfalse & falseb)
116116
elseif operation(expr) == Base.:(!) # NOT of expression
117117
(a,) = arguments(expr)
118118
(rw, ctrue, cfalse) = cw(a, dep)

test/if_lifting.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,43 @@ end
124124
end
125125
@test_nowarn @mtkbuild sys=SimpleAbs() additional_passes=[IfLifting]
126126
end
127+
128+
@testset "Nested conditions are handled properly" begin
129+
@mtkmodel RampModel begin
130+
@variables begin
131+
x(t)
132+
y(t)
133+
end
134+
@parameters begin
135+
start_time = 1.0
136+
duration = 1.0
137+
height = 1.0
138+
end
139+
@equations begin
140+
y ~ ifelse(start_time < t,
141+
ifelse(t < start_time + duration,
142+
(t - start_time) * height / duration, height),
143+
0.0)
144+
D(x) ~ y
145+
end
146+
end
147+
@mtkbuild sys = RampModel()
148+
@mtkbuild sys2=RampModel() additional_passes=[IfLifting]
149+
prob = ODEProblem(sys, [sys.x => 1.0], (0.0, 3.0))
150+
prob2 = ODEProblem(sys2, [sys.x => 1.0], (0.0, 3.0))
151+
sol = solve(prob)
152+
sol2 = solve(prob2)
153+
@test sol(0.99)[1] > 1.0
154+
@test sol2(0.99)[1] == 1.0
155+
# During ramp
156+
# D(x) ~ t - 1
157+
# x ~ t^2 / 2 - t + C, and `x(1) ~ 1` => `C = 3/2`
158+
# x(1.01) ~ 1.01^2 / 2 - 1.01 + 3/2 ~ 1.00005
159+
@test sol2(1.01)[1] 1.00005
160+
@test sol2(2)[1] 1.5
161+
# After ramp
162+
# D(x) ~ 1
163+
# x ~ t + C and `x(2) ~ 3/2` => `C = -1/2`
164+
# x(3) ~ 3 - 1/2
165+
@test sol2(3)[1] 5 / 2
166+
end

0 commit comments

Comments
 (0)