Skip to content

Commit f3ca2a5

Browse files
committed
[Nonlinear] add support for simplifying NonlinearFunction
1 parent 57d0f65 commit f3ca2a5

File tree

3 files changed

+452
-0
lines changed

3 files changed

+452
-0
lines changed

src/Nonlinear/Nonlinear.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,6 @@ include("model.jl")
4848
include("evaluator.jl")
4949

5050
include("ReverseAD/ReverseAD.jl")
51+
include("SymbolicAD/SymbolicAD.jl")
5152

5253
end # module
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# Copyright (c) 2017: Miles Lubin and contributors
2+
# Copyright (c) 2017: Google Inc.
3+
#
4+
# Use of this source code is governed by an MIT-style license that can be found
5+
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
6+
7+
module SymbolicAD
8+
9+
import MathOptInterface as MOI
10+
11+
"""
12+
simplify(f)
13+
14+
Return a simplified version of the function `f`.
15+
16+
!!! warning
17+
This function is not type stable by design.
18+
"""
19+
simplify(f) = f
20+
21+
function simplify(f::MOI.ScalarAffineFunction{T}) where {T}
22+
f = MOI.Utilities.canonical(f)
23+
if isempty(f.terms)
24+
return f.constant
25+
end
26+
return f
27+
end
28+
29+
function simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}
30+
f = MOI.Utilities.canonical(f)
31+
if isempty(f.quadratic_terms)
32+
return simplify(MOI.ScalarAffineFunction(f.affine_terms, f.constant))
33+
end
34+
return f
35+
end
36+
37+
function simplify(f::MOI.ScalarNonlinearFunction)
38+
for i in 1:length(f.args)
39+
f.args[i] = simplify(f.args[i])
40+
end
41+
return _eval_if_constant(simplify(Val(f.head), f))
42+
end
43+
44+
function simplify(f::MOI.VectorAffineFunction{T}) where {T}
45+
f = MOI.Utilities.canonical(f)
46+
if isempty(f.terms)
47+
return f.constant
48+
end
49+
return f
50+
end
51+
52+
function simplify(f::MOI.VectorQuadraticFunction{T}) where {T}
53+
f = MOI.Utilities.canonical(f)
54+
if isempty(f.quadratic_terms)
55+
return simplify(MOI.VectorAffineFunction(f.affine_terms, f.constants))
56+
end
57+
return f
58+
end
59+
60+
function simplify(f::MOI.VectorNonlinearFunction)
61+
return MOI.VectorNonlinearFunction(simplify.(f.rows))
62+
end
63+
64+
# If a ScalarNonlinearFunction has only constant arguments, we should return
65+
# the vaålue.
66+
67+
_isnum(::Any) = false
68+
69+
_isnum(::Union{Bool,Integer,Float64}) = true
70+
71+
function _eval_if_constant(f::MOI.ScalarNonlinearFunction)
72+
if all(_isnum, f.args) && hasproperty(Base, f.head)
73+
return getproperty(Base, f.head)(f.args...)
74+
end
75+
return f
76+
end
77+
78+
_eval_if_constant(f) = f
79+
80+
_iszero(x::Any) = _isnum(x) && iszero(x)
81+
82+
_isone(x::Any) = _isnum(x) && isone(x)
83+
84+
"""
85+
_isexpr(f::Any, head::Symbol[, n::Int])
86+
87+
Return `true` if `f` is a `ScalarNonlinearFunction` with head `head` and, if
88+
specified, `n` arguments.
89+
"""
90+
_isexpr(::Any, ::Symbol, n::Int = 0) = false
91+
92+
_isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol) = f.head == head
93+
94+
function _isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol, n::Int)
95+
return _isexpr(f, head) && length(f.args) == n
96+
end
97+
98+
"""
99+
simplify(::Val{head}, f::MOI.ScalarNonlinearFunction)
100+
101+
Return a simplified version of `f` where the head of `f` is `head`.
102+
103+
Implementing this method enables custom simplification rules for different
104+
operators without needing a giant switch statement.
105+
"""
106+
simplify(::Val, f::MOI.ScalarNonlinearFunction) = f
107+
108+
function simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
109+
new_args = Any[]
110+
first_constant = 0
111+
for arg in f.args
112+
if _isexpr(arg, :*)
113+
# If the child is a :*, lift its arguments to the parent
114+
append!(new_args, arg.args)
115+
elseif _iszero(arg)
116+
# If any argument is zero, the entire expression must be false
117+
return false
118+
elseif _isone(arg)
119+
# Skip any arguments that are one
120+
elseif arg isa Real
121+
# Collect all constant arguments into a single value
122+
if first_constant == 0
123+
push!(new_args, arg)
124+
first_constant = length(new_args)
125+
else
126+
new_args[first_constant] *= arg
127+
end
128+
else
129+
push!(new_args, arg)
130+
end
131+
end
132+
if isempty(new_args)
133+
return true
134+
elseif length(new_args) == 1
135+
return only(new_args)
136+
end
137+
return MOI.ScalarNonlinearFunction(:*, new_args)
138+
end
139+
140+
function simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
141+
if length(f.args) == 1
142+
# +(x) -> x
143+
return only(f.args)
144+
elseif length(f.args) == 2 && _isexpr(f.args[2], :-, 1)
145+
# +(x, -y) -> -(x, y)
146+
return MOI.ScalarNonlinearFunction(
147+
:-,
148+
Any[f.args[1], f.args[2].args[1]],
149+
)
150+
end
151+
new_args = Any[]
152+
first_constant = 0
153+
for arg in f.args
154+
if _isexpr(arg, :+)
155+
# If a child is a :+, lift its arguments to the parent
156+
append!(new_args, arg.args)
157+
elseif _iszero(arg)
158+
# Skip any zero arguments
159+
elseif arg isa Real
160+
# Collect all constant arguments into a single value
161+
if first_constant == 0
162+
push!(new_args, arg)
163+
first_constant = length(new_args)
164+
else
165+
new_args[first_constant] += arg
166+
end
167+
else
168+
push!(new_args, arg)
169+
end
170+
end
171+
if isempty(new_args)
172+
# +() -> false
173+
return false
174+
elseif length(new_args) == 1
175+
# +(x) -> x
176+
return only(new_args)
177+
end
178+
return MOI.ScalarNonlinearFunction(:+, new_args)
179+
end
180+
181+
function simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
182+
if length(f.args) == 1
183+
if _isexpr(f.args[1], :-, 1)
184+
# -(-(x)) => x
185+
return f.args[1].args[1]
186+
end
187+
elseif length(f.args) == 2
188+
if _iszero(f.args[1])
189+
# 0 - x => -x
190+
return MOI.ScalarNonlinearFunction(:-, Any[f.args[2]])
191+
elseif _iszero(f.args[2])
192+
# x - 0 => x
193+
return f.args[1]
194+
elseif f.args[1] == f.args[2]
195+
# x - x => 0
196+
return false
197+
elseif _isexpr(f.args[2], :-, 1)
198+
# x - -(y) => x + y
199+
return MOI.ScalarNonlinearFunction(
200+
:+,
201+
Any[f.args[1], f.args[2].args[1]],
202+
)
203+
end
204+
end
205+
return f
206+
end
207+
208+
function simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
209+
if _iszero(f.args[2])
210+
# x^0 => 1
211+
return true
212+
elseif _isone(f.args[2])
213+
# x^1 => x
214+
return f.args[1]
215+
elseif _iszero(f.args[1])
216+
# 0^x => 0
217+
return false
218+
elseif _isone(f.args[1])
219+
# 1^x => 1
220+
return true
221+
end
222+
return f
223+
end
224+
225+
end # module

0 commit comments

Comments
 (0)