@@ -20,6 +20,160 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
2020 :row ,
2121]
2222
23+ function _validate_register_assumptions (
24+ f:: Function ,
25+ name:: Symbol ,
26+ dimension:: Integer ,
27+ )
28+ # Assumption 1: check that `f` can be called with `Float64` arguments.
29+ y = 0.0
30+ try
31+ if dimension == 1
32+ y = f (0.0 )
33+ else
34+ y = f (zeros (dimension)... )
35+ end
36+ catch
37+ # We hit some other error, perhaps we called a function like log(-1).
38+ # Ignore for now, and hope that a useful error is shown to the user
39+ # during the solve.
40+ end
41+ if ! (y isa Real)
42+ error (
43+ " Expected return type of `Float64` from the user-defined " *
44+ " function :$(name) , but got `$(typeof (y)) `." ,
45+ )
46+ end
47+ # Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
48+ try
49+ if dimension == 1
50+ ForwardDiff. derivative (f, 0.0 )
51+ else
52+ ForwardDiff. gradient (x -> f (x... ), zeros (dimension))
53+ end
54+ catch err
55+ if err isa MethodError
56+ error (
57+ " Unable to register the function :$name .\n\n " *
58+ _FORWARD_DIFF_METHOD_ERROR_HELPER,
59+ )
60+ end
61+ # We hit some other error, perhaps we called a function like log(-1).
62+ # Ignore for now, and hope that a useful error is shown to the user
63+ # during the solve.
64+ end
65+ return
66+ end
67+
68+ function _checked_derivative (f:: F , op:: Symbol ) where {F}
69+ return function (x)
70+ try
71+ return ForwardDiff. derivative (f, x)
72+ catch err
73+ _intercept_ForwardDiff_MethodError (err, op)
74+ end
75+ end
76+ end
77+
78+ """
79+ check_return_type(::Type{T}, ret::S) where {T,S}
80+
81+ Overload this method for new types `S` to throw an informative error if a
82+ user-defined function returns the type `S` instead of `T`.
83+ """
84+ check_return_type (:: Type{T} , ret:: T ) where {T} = nothing
85+
86+ function check_return_type (:: Type{T} , ret) where {T}
87+ return error (
88+ " Expected return type of $T from a user-defined function, but got " *
89+ " $(typeof (ret)) ." ,
90+ )
91+ end
92+
93+ struct _UnivariateOperator{F,F′,F′′}
94+ f:: F
95+ f′:: F′
96+ f′′:: F′′
97+ function _UnivariateOperator (
98+ f:: Function ,
99+ f′:: Function ,
100+ f′′:: Union{Nothing,Function} = nothing ,
101+ )
102+ return new {typeof(f),typeof(f′),typeof(f′′)} (f, f′, f′′)
103+ end
104+ end
105+
106+ function _UnivariateOperator (op:: Symbol , f:: Function )
107+ _validate_register_assumptions (f, op, 1 )
108+ f′ = _checked_derivative (f, op)
109+ return _UnivariateOperator (op, f, f′)
110+ end
111+
112+ function _UnivariateOperator (op:: Symbol , f:: Function , f′:: Function )
113+ try
114+ _validate_register_assumptions (f′, op, 1 )
115+ f′′ = _checked_derivative (f′, op)
116+ return _UnivariateOperator (f, f′, f′′)
117+ catch
118+ return _UnivariateOperator (f, f′, nothing )
119+ end
120+ end
121+
122+ function _UnivariateOperator (:: Symbol , f:: Function , f′:: Function , f′′:: Function )
123+ return _UnivariateOperator (f, f′, f′′)
124+ end
125+
126+ struct OperatorRegistry
127+ # NODE_CALL_UNIVARIATE
128+ univariate_operators:: Vector{Symbol}
129+ univariate_operator_to_id:: Dict{Symbol,Int}
130+ univariate_user_operator_start:: Int
131+ registered_univariate_operators:: Vector{_UnivariateOperator}
132+ # NODE_CALL_MULTIVARIATE
133+ multivariate_operators:: Vector{Symbol}
134+ multivariate_operator_to_id:: Dict{Symbol,Int}
135+ multivariate_user_operator_start:: Int
136+ registered_multivariate_operators:: Vector {
137+ MOI. Nonlinear. _MultivariateOperator,
138+ }
139+ # NODE_LOGIC
140+ logic_operators:: Vector{Symbol}
141+ logic_operator_to_id:: Dict{Symbol,Int}
142+ # NODE_COMPARISON
143+ comparison_operators:: Vector{Symbol}
144+ comparison_operator_to_id:: Dict{Symbol,Int}
145+ function OperatorRegistry ()
146+ univariate_operators = copy (MOI. Nonlinear. DEFAULT_UNIVARIATE_OPERATORS)
147+ multivariate_operators = copy (DEFAULT_MULTIVARIATE_OPERATORS)
148+ logic_operators = [:&& , :|| ]
149+ comparison_operators = [:<= , :(== ), :>= , :< , :> ]
150+ return new (
151+ # NODE_CALL_UNIVARIATE
152+ univariate_operators,
153+ Dict {Symbol,Int} (
154+ op => i for (i, op) in enumerate (univariate_operators)
155+ ),
156+ length (univariate_operators),
157+ _UnivariateOperator[],
158+ # NODE_CALL
159+ multivariate_operators,
160+ Dict {Symbol,Int} (
161+ op => i for (i, op) in enumerate (multivariate_operators)
162+ ),
163+ length (multivariate_operators),
164+ MOI. Nonlinear. _MultivariateOperator[],
165+ # NODE_LOGIC
166+ logic_operators,
167+ Dict {Symbol,Int} (op => i for (i, op) in enumerate (logic_operators)),
168+ # NODE_COMPARISON
169+ comparison_operators,
170+ Dict {Symbol,Int} (
171+ op => i for (i, op) in enumerate (comparison_operators)
172+ ),
173+ )
174+ end
175+ end
176+
23177function eval_logic_function (
24178 :: OperatorRegistry ,
25179 op:: Symbol ,
36190
37191function _generate_eval_univariate ()
38192 exprs = map (Nonlinear. DEFAULT_UNIVARIATE_OPERATORS) do op
39- return :(return (value_deriv_and_second ($ op, x)[1 ], value_deriv_and_second ($ op, x)[2 ]))
193+ return :(
194+ return (
195+ value_deriv_and_second ($ op, x)[1 ],
196+ value_deriv_and_second ($ op, x)[2 ],
197+ )
198+ )
40199 end
41200 return Nonlinear. _create_binary_switch (1 : length (exprs), exprs)
42201end
@@ -177,109 +336,6 @@ function eval_multivariate_hessian(
177336 return true
178337end
179338
180- function _validate_register_assumptions (
181- f:: Function ,
182- name:: Symbol ,
183- dimension:: Integer ,
184- )
185- # Assumption 1: check that `f` can be called with `Float64` arguments.
186- y = 0.0
187- try
188- if dimension == 1
189- y = f (0.0 )
190- else
191- y = f (zeros (dimension)... )
192- end
193- catch
194- # We hit some other error, perhaps we called a function like log(-1).
195- # Ignore for now, and hope that a useful error is shown to the user
196- # during the solve.
197- end
198- if ! (y isa Real)
199- error (
200- " Expected return type of `Float64` from the user-defined " *
201- " function :$(name) , but got `$(typeof (y)) `." ,
202- )
203- end
204- # Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
205- try
206- if dimension == 1
207- ForwardDiff. derivative (f, 0.0 )
208- else
209- ForwardDiff. gradient (x -> f (x... ), zeros (dimension))
210- end
211- catch err
212- if err isa MethodError
213- error (
214- " Unable to register the function :$name .\n\n " *
215- _FORWARD_DIFF_METHOD_ERROR_HELPER,
216- )
217- end
218- # We hit some other error, perhaps we called a function like log(-1).
219- # Ignore for now, and hope that a useful error is shown to the user
220- # during the solve.
221- end
222- return
223- end
224-
225- function _checked_derivative (f:: F , op:: Symbol ) where {F}
226- return function (x)
227- try
228- return ForwardDiff. derivative (f, x)
229- catch err
230- _intercept_ForwardDiff_MethodError (err, op)
231- end
232- end
233- end
234-
235- """
236- check_return_type(::Type{T}, ret::S) where {T,S}
237-
238- Overload this method for new types `S` to throw an informative error if a
239- user-defined function returns the type `S` instead of `T`.
240- """
241- check_return_type (:: Type{T} , ret:: T ) where {T} = nothing
242-
243- function check_return_type (:: Type{T} , ret) where {T}
244- return error (
245- " Expected return type of $T from a user-defined function, but got " *
246- " $(typeof (ret)) ." ,
247- )
248- end
249-
250- struct _UnivariateOperator{F,F′,F′′}
251- f:: F
252- f′:: F′
253- f′′:: F′′
254- function _UnivariateOperator (
255- f:: Function ,
256- f′:: Function ,
257- f′′:: Union{Nothing,Function} = nothing ,
258- )
259- return new {typeof(f),typeof(f′),typeof(f′′)} (f, f′, f′′)
260- end
261- end
262-
263- function _UnivariateOperator (op:: Symbol , f:: Function )
264- _validate_register_assumptions (f, op, 1 )
265- f′ = _checked_derivative (f, op)
266- return _UnivariateOperator (op, f, f′)
267- end
268-
269- function _UnivariateOperator (op:: Symbol , f:: Function , f′:: Function )
270- try
271- _validate_register_assumptions (f′, op, 1 )
272- f′′ = _checked_derivative (f′, op)
273- return _UnivariateOperator (f, f′, f′′)
274- catch
275- return _UnivariateOperator (f, f′, nothing )
276- end
277- end
278-
279- function _UnivariateOperator (:: Symbol , f:: Function , f′:: Function , f′′:: Function )
280- return _UnivariateOperator (f, f′, f′′)
281- end
282-
283339function eval_univariate_function (operator:: _UnivariateOperator , x:: T ) where {T}
284340 ret = operator. f (x)
285341 check_return_type (T, ret)
0 commit comments