Skip to content

Commit 46543d7

Browse files
committed
New operation: Division!
1 parent 21df73b commit 46543d7

File tree

5 files changed

+164
-12
lines changed

5 files changed

+164
-12
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ using `DiffEqGPU.jl`.
151151
## Limitations
152152

153153
Currently, as proof-of-concept, `SourceCodeMcCormick` can only handle functions with
154-
addition (+), subtraction (-), multiplication (\*), powers of 2 (^2), natural base
154+
addition (+), subtraction (-), multiplication (\*), division (/), powers of 2 (^2), natural base
155155
exponentials (exp), and minimum/maximum (min/max) expressions. Future work will include
156156
adding other operations found in `McCormick.jl`.
157157

src/interval/rules.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,40 @@ function transform_rule(::IntervalTransform, ::typeof(*), zL, zU, xL, xU, yL, yU
5555
return rl, ru
5656
end
5757

58+
function transform_rule(::IntervalTransform, ::typeof(/), zL, zU, xL, xU, yL, yU)
59+
rl = Equation(zL, IfElse.ifelse(yL > 0.0,
60+
IfElse.ifelse(xL >= 0.0, xL/yU, #y strictly positive
61+
IfElse.ifelse(xU <= 0.0, xL/yL, xL/yL)),
62+
IfElse.ifelse(yU < 0.0, #y strictly negative
63+
IfElse.ifelse(xL >= 0.0, xU/yU,
64+
IfElse.ifelse(xU <= 0.0, xU/yL, xU/yU)),
65+
IfElse.ifelse(yL == 0.0, # y contains 0 and at least yL is 0
66+
IfElse.ifelse(xL >= 0.0, IfElse.ifelse(yU == 0.0, NaN, IfElse.ifelse(xU == 0.0, 0.0, xL/yU)),
67+
IfElse.ifelse(xU <= 0.0, IfElse.ifelse(yU == 0.0, NaN, IfElse.ifelse(xL == 0.0, 0.0, NaN)),
68+
NaN)),
69+
IfElse.ifelse(yU == 0.0, #y contains 0 and yU is 0 but not yL
70+
IfElse.ifelse(xL >= 0.0, IfElse.ifelse(xU == 0.0, 0.0, NaN),
71+
IfElse.ifelse(xU <= 0.0, IfElse.ifelse(xL == 0.0, 0.0, xU/yL),
72+
NaN)),
73+
NaN)))))
74+
ru = Equation(zU, IfElse.ifelse(yL > 0.0,
75+
IfElse.ifelse(xL >= 0.0, xU/yL, #y strictly positive
76+
IfElse.ifelse(xU <= 0.0, xU/yU, xU/yL)),
77+
IfElse.ifelse(yU < 0.0, #y strictly negative
78+
IfElse.ifelse(xL >= 0.0, xL/yL,
79+
IfElse.ifelse(xU <= 0.0, xL/yU, xL/yU)),
80+
IfElse.ifelse(yL == 0.0, # y contains 0 and at least yL is 0
81+
IfElse.ifelse(xL >= 0.0, IfElse.ifelse(yU == 0.0, NaN, IfElse.ifelse(xU == 0.0, 0.0, NaN)),
82+
IfElse.ifelse(xU <= 0.0, IfElse.ifelse(yU == 0.0, NaN, IfElse.ifelse(xL == 0.0, 0.0, xU/yU)),
83+
NaN)),
84+
IfElse.ifelse(yU == 0.0, #y contains 0 and yU is 0 but not yL
85+
IfElse.ifelse(xL >= 0.0, IfElse.ifelse(xU == 0.0, 0.0, xL/yL),
86+
IfElse.ifelse(xU <= 0.0, IfElse.ifelse(xL == 0.0, 0.0, NaN),
87+
NaN)),
88+
NaN)))))
89+
return rl, ru
90+
end
91+
5892

5993
function transform_rule(::IntervalTransform, ::typeof(min), zL, zU, xL, xU, yL, yU)
6094
rl = Equation(zL, min(xL, yL))

src/relaxation/rules.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,51 @@ function transform_rule(::McCormickTransform, ::typeof(*), zL, zU, zcv, zcc, xL,
165165

166166
end
167167

168+
function transform_rule(::McCormickTransform, ::typeof(/), zL, zU, zcv, zcc, xL, xU, xcv, xcc, yL, yU, ycv, ycc)
169+
# For division, we do x*(y^-1). First we make a NaN-checker to see if
170+
# the denominator contains 0
171+
NaN_check = IfElse.ifelse(yL < 0.0, IfElse.ifelse(yU > 0.0, NaN, 1.0), 1.0)
172+
173+
# Next we calculate the inverse of y
174+
yL_inv = inv(yU)
175+
yU_inv = inv(yL)
176+
ycv_inv = IfElse.ifelse(yL > 0.0, IfElse.ifelse(yU <= ycv, 1.0 ./ ycv,
177+
IfElse.ifelse(yU >= ycc, 1.0 ./ ycc, 1.0 ./ yU)),
178+
IfElse.ifelse(yU < 0.0, IfElse.ifelse(yL == yU, mid_expr(ycc, ycv, yL).^(-1),
179+
((yL.^(-1))*(yU - mid_expr(ycc, ycv, yL)) + (yU.^(-1))*(mid_expr(ycc, ycv, yL) - yL))./(yU - yL)),
180+
NaN))
181+
ycc_inv = IfElse.ifelse(yL > 0.0, IfElse.ifelse(yL <= ycv, (yU + yL - ycv)./(yL*yU),
182+
IfElse.ifelse(yL >= ycc, (yU + yL - ycc)./(yL*yU), 1.0 ./ yL)),
183+
IfElse.ifelse(yU < 0.0, mid_expr(ycc, ycv, yU).^(-1),
184+
NaN))
185+
186+
# Now we use the multiplication rules, but replacing each instance of
187+
# y with its inverse.
188+
rcv = Equation(zcv, IfElse.ifelse(xL >= 0.0,
189+
IfElse.ifelse(yL_inv >= 0.0, max(yU_inv*xcv + xU*ycv_inv - xU*yU_inv, yL_inv*xcv + xL*ycv_inv - xL*yL_inv),
190+
IfElse.ifelse(yU_inv <= 0.0, -min((-yU_inv)*xcc + xU*(-ycv_inv) - xU*(-yU_inv), (-yL_inv)*xcc + xU*(-ycv_inv) - xL*(-yL_inv))*NaN_check,
191+
max(yU_inv*xcv + xU*ycv_inv - xU*yU_inv, yL_inv*xcc + xL*ycv_inv - xL*yL_inv)*NaN_check)),
192+
IfElse.ifelse(xU <= 0.0,
193+
IfElse.ifelse(yL_inv >= 0.0, -min(yL_inv*(-xcv) + (-xL)*ycc_inv - (-xL)*yL_inv, yU_inv*(-xcv) + (-xU)*ycc_inv - (-xU)*yU_inv)*NaN_check,
194+
IfElse.ifelse(yU_inv <= 0.0, max(yL_inv*xcc + xL*ycc_inv - xL*yL_inv, yU_inv*xcc + xU*ycc_inv - xU*yU_inv)*NaN_check,
195+
-min(yL_inv*(-xcc) + (-xL)*ycc_inv - (-xL)*yL_inv, yU_inv*(-xcv) + (-xU)*ycc_inv - (-xU)*yU_inv)*NaN_check)),
196+
IfElse.ifelse(yL_inv >= 0.0, max(xU*ycv_inv + yU_inv*xcv - yU_inv*xU, xL*ycc_inv + yL_inv*xcv - yL_inv*xL)*NaN_check,
197+
IfElse.ifelse(yU_inv <= 0.0, -min(xL*(-ycc_inv) + (-yL_inv)*xcc - (-yL_inv)*xL, xU*(-ycv_inv) + (-yU_inv)*xcc - (-yU_inv)*xU)*NaN_check,
198+
max(yU_inv*xcv + xU*ycv_inv - xU*yU_inv, yL_inv*xcc + xL*ycc_inv - xL*yL_inv)*NaN_check)))))
199+
rcc = Equation(zcc, IfElse.ifelse(xL >= 0.0,
200+
IfElse.ifelse(yL_inv >= 0.0, min(yL_inv*xcc + xU*ycc_inv - xU*yL_inv, yU_inv*xcc + xL*ycc_inv - xL*yU_inv)*NaN_check,
201+
IfElse.ifelse(yU_inv <= 0.0, -max((-yL_inv)*xcv + xU*(-ycc_inv) - xU*(-yL_inv), (-yU_inv)*xcv + xL*(-ycc_inv) - xL*(-yU_inv))*NaN_check,
202+
min(yL_inv*xcv + xU*ycc_inv - xU*yL_inv, yU_inv*xcc + xL*ycc_inv - xL*yU_inv)*NaN_check)),
203+
IfElse.ifelse(xU <= 0.0,
204+
IfElse.ifelse(yL_inv >= 0.0, -max(yU_inv*(-xcc) + (-xL)*ycv_inv - (-xL)*yU_inv, yL_inv*(-xcc) + (-xU)*ycv_inv - (-xU)*yL_inv)*NaN_check,
205+
IfElse.ifelse(yU_inv <= 0.0, min(yU_inv*xcv + xL*ycv_inv - xL*yU_inv, yL_inv*xcv + xL*ycv_inv - xU*yL_inv)*NaN_check,
206+
-max(yU_inv*(-xcc) + (-xL)*ycv_inv - (-xL)*yU_inv, yL_inv*(-xcv) + (-xU)*ycv_inv - (-xU)*yL_inv)*NaN_check)),
207+
IfElse.ifelse(yL_inv >= 0.0, min(xL*ycv_inv + yU_inv*xcc - yU_inv*xL, xU*ycc_inv + yL_inv*xcc - yL_inv*xU)*NaN_check,
208+
IfElse.ifelse(yU_inv <= 0.0, -max(xU*(-ycc_inv) + (-yL_inv)*xcv - (-yL_inv)*xU, xL*(-ycv_inv) + (-yU_inv)*xcv - (-yU_inv)*xL)*NaN_check,
209+
min(yL_inv*xcv + xU*ycc_inv - xU*yL_inv, yU_inv*xcc + xL*ycv_inv - xL*yU_inv)*NaN_check)))))
210+
return rcv, rcc
211+
end
212+
168213
function transform_rule(::McCormickTransform, ::typeof(min), zL, zU, zcv, zcc, xL, xU, xcv, xcc, yL, yU, ycv, ycc)
169214
rcv = Equation(zcv, min(xcv, ycv))
170215
rcc = Equation(zcc, min(xcc, ycc))

src/transform/utilities.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ arity(a::Term{Real, Nothing}) = 1
55
arity(a::SymbolicUtils.Add) = length(a.dict) + (~iszero(a.coeff))
66
arity(a::SymbolicUtils.Mul) = length(a.dict) + (~isone(a.coeff))
77
arity(a::SymbolicUtils.Pow) = 2
8+
arity(a::SymbolicUtils.Div) = 2
89

910
op(a::Equation) = op(a.rhs)
1011
op(::SymbolicUtils.Add) = +
@@ -38,6 +39,12 @@ function sub_2(a::SymbolicUtils.Mul)
3839
sorted_dict = sort(collect(a.dict), by=x->string(x[1]))
3940
return sorted_dict[2].first
4041
end
42+
function sub_1(a::SymbolicUtils.Div)
43+
return a.num
44+
end
45+
function sub_2(a::SymbolicUtils.Div)
46+
return a.den
47+
end
4148
function sub_1(a::SymbolicUtils.Pow)
4249
return a.base
4350
end
@@ -235,7 +242,6 @@ end
235242
function _pull_vars(term::SymbolicUtils.Add, vars::Vector{Num}, strings::Vector{String})
236243
args = arguments(term)
237244
for arg in args
238-
# if (typeof(arg) == Term{Real, Nothing}) || (typeof(arg) == Sym{Real, Base.ImmutableDict{DataType, Any}})
239245
if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}})
240246
if ~(string(arg) in strings)
241247
push!(strings, string(arg))
@@ -253,7 +259,23 @@ end
253259
function _pull_vars(term::SymbolicUtils.Mul, vars::Vector{Num}, strings::Vector{String})
254260
args = arguments(term)
255261
for arg in args
256-
# if (typeof(arg) == Term{Real, Nothing}) || (typeof(arg) == Sym{Real, Base.ImmutableDict{DataType, Any}})
262+
if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}})
263+
if ~(string(arg) in strings)
264+
push!(strings, string(arg))
265+
push!(vars, arg)
266+
end
267+
elseif (typeof(arg) <: Int) || (typeof(arg) <: AbstractFloat)
268+
nothing
269+
else
270+
vars, strings = _pull_vars(arg, vars, strings)
271+
end
272+
end
273+
return vars, strings
274+
end
275+
276+
function _pull_vars(term::SymbolicUtils.Div, vars::Vector{Num}, strings::Vector{String})
277+
args = arguments(term)
278+
for arg in args
257279
if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}})
258280
if ~(string(arg) in strings)
259281
push!(strings, string(arg))
@@ -271,7 +293,6 @@ end
271293
function _pull_vars(term::SymbolicUtils.Pow, vars::Vector{Num}, strings::Vector{String})
272294
args = arguments(term)
273295
for arg in args
274-
# if (typeof(arg) == Term{Real, Nothing}) || (typeof(arg) == Sym{Real, Base.ImmutableDict{DataType, Any}})
275296
if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}})
276297
if ~(string(arg) in strings)
277298
push!(strings, string(arg))
@@ -289,7 +310,6 @@ end
289310
function _pull_vars(term::SymbolicUtils.Term{Real, Nothing}, vars::Vector{Num}, strings::Vector{String})
290311
args = arguments(term)
291312
for arg in args
292-
# if (typeof(arg) == Term{Real, Nothing}) || (typeof(arg) == Sym{Real, Base.ImmutableDict{DataType, Any}})
293313
if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}})
294314
if ~(string(arg) in strings)
295315
push!(strings, string(arg))
@@ -307,7 +327,6 @@ end
307327
function _pull_vars(term::SymbolicUtils.Term{Bool, Nothing}, vars::Vector{Num}, strings::Vector{String})
308328
args = arguments(term)
309329
for arg in args
310-
# if (typeof(arg) == Term{Real, Nothing}) || (typeof(arg) == Sym{Real, Base.ImmutableDict{DataType, Any}})
311330
if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}})
312331
if ~(string(arg) in strings)
313332
push!(strings, string(arg))
@@ -322,6 +341,10 @@ function _pull_vars(term::SymbolicUtils.Term{Bool, Nothing}, vars::Vector{Num},
322341
return vars, strings
323342
end
324343

344+
function _pull_vars(term::SymbolicUtils.Term{Float64, Nothing}, vars::Vector{Num}, strings::Vector{String})
345+
return vars, strings
346+
end
347+
325348
"""
326349
shrink_eqs(::Vector{Equation})
327350
shrink_eqs(::Vector{Equation}, ::Int)

test/runtests.jl

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ end
1010
to_compute = x*y
1111
mult_lo, mult_hi, mult_cv, mult_cc, order = all_evaluators(to_compute)
1212

13-
xMC = MC{1,NS}(-2.0, Interval(-3.0, -1.0), 1)
14-
yMC = MC{1,NS}(4.0, Interval(2.0, 6.0), 2)
13+
xMC = MC{2,NS}(-2.0, Interval(-3.0, -1.0), 1)
14+
yMC = MC{2,NS}(4.0, Interval(2.0, 6.0), 2)
1515
zMC = 0.5*xMC*yMC
1616
neg = zMC
1717
mix = zMC + 4.0
@@ -63,8 +63,8 @@ end
6363
to_compute = x+y
6464
add_lo, add_hi, add_cv, add_cc, order = all_evaluators(to_compute)
6565

66-
xMC = MC{1,NS}(-2.0, Interval(-3.0, -1.0), 1)
67-
yMC = MC{1,NS}(4.0, Interval(2.0, 6.0), 2)
66+
xMC = MC{2,NS}(-2.0, Interval(-3.0, -1.0), 1)
67+
yMC = MC{2,NS}(4.0, Interval(2.0, 6.0), 2)
6868
zMC = 0.5*xMC*yMC
6969
neg = zMC
7070
mix = zMC + 4.0
@@ -117,8 +117,8 @@ end
117117
to_compute = x-y
118118
sub_lo, sub_hi, sub_cv, sub_cc, order = all_evaluators(to_compute)
119119

120-
xMC = MC{1,NS}(-2.0, Interval(-3.0, -1.0), 1)
121-
yMC = MC{1,NS}(4.0, Interval(2.0, 6.0), 2)
120+
xMC = MC{2,NS}(-2.0, Interval(-3.0, -1.0), 1)
121+
yMC = MC{2,NS}(4.0, Interval(2.0, 6.0), 2)
122122
zMC = 0.5*xMC*yMC
123123
neg = zMC
124124
mix = zMC + 4.0
@@ -164,3 +164,53 @@ end
164164
@test eval_check(sub_cc, pos, mix) == (pos-mix).cc
165165
@test eval_check(sub_cc, pos, pos) == (pos-pos).cc
166166
end
167+
168+
169+
@testset "Division" begin
170+
@variables x, y
171+
to_compute = x/y
172+
div_lo, div_hi, div_cv, div_cc, order = all_evaluators(to_compute)
173+
174+
xMC = MC{2,NS}(-2.0, Interval(-3.0, -1.0), 1)
175+
yMC = MC{2,NS}(4.0, Interval(2.0, 6.0), 2)
176+
zMC = 0.5*xMC*yMC
177+
neg = zMC
178+
mix = zMC + 4.0
179+
pos = zMC + 10.0
180+
181+
@test abs(eval_check(div_lo, 0.99*neg, neg) - (0.99*neg/neg).Intv.lo) < 1e-15
182+
@test isnan(eval_check(div_lo, neg, mix))
183+
@test eval_check(div_lo, neg, pos) == (neg/pos).Intv.lo
184+
@test eval_check(div_lo, mix, neg) == (mix/neg).Intv.lo
185+
@test eval_check(div_lo, mix, pos) == (mix/pos).Intv.lo
186+
@test eval_check(div_lo, pos, neg) == (pos/neg).Intv.lo
187+
@test isnan(eval_check(div_lo, pos, mix))
188+
@test abs(eval_check(div_lo, 0.99*pos, pos) - (0.99*pos/pos).Intv.lo) < 1e-15
189+
190+
@test abs(eval_check(div_hi, 0.99*neg, neg) - (0.99*neg/neg).Intv.hi) < 1e-15
191+
@test isnan(eval_check(div_hi, neg, mix))
192+
@test eval_check(div_hi, neg, pos) == (neg/pos).Intv.hi
193+
@test eval_check(div_hi, mix, neg) == (mix/neg).Intv.hi
194+
@test eval_check(div_hi, mix, pos) == (mix/pos).Intv.hi
195+
@test eval_check(div_hi, pos, neg) == (pos/neg).Intv.hi
196+
@test isnan(eval_check(div_hi, pos, mix))
197+
@test abs(eval_check(div_hi, 0.99*pos, pos) - (0.99*pos/pos).Intv.hi) < 1e-15
198+
199+
@test abs(eval_check(div_cv, 0.99*neg, neg) - (0.99*neg/neg).cv) < 1e-15
200+
@test isnan(eval_check(div_cv, neg, mix))
201+
@test eval_check(div_cv, neg, pos) == (neg/pos).cv
202+
@test eval_check(div_cv, mix, neg) == (mix/neg).cv
203+
@test eval_check(div_cv, mix, pos) == (mix/pos).cv
204+
@test eval_check(div_cv, pos, neg) == (pos/neg).cv
205+
@test isnan(eval_check(div_cv, pos, mix))
206+
@test abs(eval_check(div_cv, 0.99*pos, pos) - (0.99*pos/pos).cv) < 1e-15
207+
208+
@test abs(eval_check(div_cc, 0.99*neg, neg) - (0.99*neg/neg).cc) < 1e-15
209+
@test isnan(eval_check(div_cc, neg, mix))
210+
@test eval_check(div_cc, neg, pos) == (neg/pos).cc
211+
@test eval_check(div_cc, mix, neg) == (mix/neg).cc
212+
@test eval_check(div_cc, mix, pos) == (mix/pos).cc
213+
@test abs(eval_check(div_cc, pos, neg) - (pos/neg).cc) < 1e-15
214+
@test isnan(eval_check(div_cc, pos, mix))
215+
@test abs(eval_check(div_cc, 0.99*pos, pos) - (0.99*pos/pos).cc) < 1e-15
216+
end

0 commit comments

Comments
 (0)