@@ -20,6 +20,158 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
2020 :row ,
2121]
2222
23+ function _validate_register_assumptions (
24+ f:: Function ,
25+ name:: Symbol ,
26+ nb_args:: Integer ,
27+ )
28+ # Assumption 1: check that `f` can be called with `Float64` arguments.
29+ arg = nb_args == 1 ? 0.0 : zeros (nb_args)
30+ if hasmethod (f, Tuple{typeof (arg)})
31+ y = f (arg)
32+ else
33+ error (
34+ " Unable to register the function :$name .\n\n " *
35+ " The function must be able to be called with $nb_args Float64 " *
36+ " arguments, but no method was found for this." ,
37+ )
38+ end
39+ if ! (y isa Real)
40+ error (
41+ " Expected return type of `Float64` from the user-defined " *
42+ " function :$(name) , but got `$(typeof (y)) `." ,
43+ )
44+ end
45+ # Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
46+ try
47+ if nb_args == 1
48+ ForwardDiff. derivative (f, 0.0 )
49+ else
50+ ForwardDiff. gradient (x -> f (x... ), zeros (nb_args))
51+ end
52+ catch err
53+ if err isa MethodError
54+ error (
55+ " Unable to register the function :$name .\n\n " *
56+ _FORWARD_DIFF_METHOD_ERROR_HELPER,
57+ )
58+ end
59+ # We hit some other error, perhaps we called a function like log(-1).
60+ # Ignore for now, and hope that a useful error is shown to the user
61+ # during the solve.
62+ end
63+ return
64+ end
65+
66+ function _checked_derivative (f:: F , op:: Symbol ) where {F}
67+ return function (x)
68+ try
69+ return ForwardDiff. derivative (f, x)
70+ catch err
71+ _intercept_ForwardDiff_MethodError (err, op)
72+ end
73+ end
74+ end
75+
76+ """
77+ check_return_type(::Type{T}, ret::S) where {T,S}
78+
79+ Overload this method for new types `S` to throw an informative error if a
80+ user-defined function returns the type `S` instead of `T`.
81+ """
82+ check_return_type (:: Type{T} , ret:: T ) where {T} = nothing
83+
84+ function check_return_type (:: Type{T} , ret) where {T}
85+ return error (
86+ " Expected return type of $T from a user-defined function, but got " *
87+ " $(typeof (ret)) ." ,
88+ )
89+ end
90+
91+ struct _UnivariateOperator{F,F′,F′′}
92+ f:: F
93+ f′:: F′
94+ f′′:: F′′
95+ function _UnivariateOperator (
96+ f:: Function ,
97+ f′:: Function ,
98+ f′′:: Union{Nothing,Function} = nothing ,
99+ )
100+ return new {typeof(f),typeof(f′),typeof(f′′)} (f, f′, f′′)
101+ end
102+ end
103+
104+ function _UnivariateOperator (op:: Symbol , f:: Function )
105+ _validate_register_assumptions (f, op, 1 )
106+ f′ = _checked_derivative (f, op)
107+ return _UnivariateOperator (op, f, f′)
108+ end
109+
110+ function _UnivariateOperator (op:: Symbol , f:: Function , f′:: Function )
111+ try
112+ _validate_register_assumptions (f′, op, 1 )
113+ f′′ = _checked_derivative (f′, op)
114+ return _UnivariateOperator (f, f′, f′′)
115+ catch
116+ return _UnivariateOperator (f, f′, nothing )
117+ end
118+ end
119+
120+ function _UnivariateOperator (:: Symbol , f:: Function , f′:: Function , f′′:: Function )
121+ return _UnivariateOperator (f, f′, f′′)
122+ end
123+
124+ struct OperatorRegistry
125+ # NODE_CALL_UNIVARIATE
126+ univariate_operators:: Vector{Symbol}
127+ univariate_operator_to_id:: Dict{Symbol,Int}
128+ univariate_user_operator_start:: Int
129+ registered_univariate_operators:: Vector{_UnivariateOperator}
130+ # NODE_CALL_MULTIVARIATE
131+ multivariate_operators:: Vector{Symbol}
132+ multivariate_operator_to_id:: Dict{Symbol,Int}
133+ multivariate_user_operator_start:: Int
134+ registered_multivariate_operators:: Vector {
135+ MOI. Nonlinear. _MultivariateOperator,
136+ }
137+ # NODE_LOGIC
138+ logic_operators:: Vector{Symbol}
139+ logic_operator_to_id:: Dict{Symbol,Int}
140+ # NODE_COMPARISON
141+ comparison_operators:: Vector{Symbol}
142+ comparison_operator_to_id:: Dict{Symbol,Int}
143+ function OperatorRegistry ()
144+ univariate_operators = copy (MOI. Nonlinear. DEFAULT_UNIVARIATE_OPERATORS)
145+ multivariate_operators = copy (DEFAULT_MULTIVARIATE_OPERATORS)
146+ logic_operators = [:&& , :|| ]
147+ comparison_operators = [:<= , :(== ), :>= , :< , :> ]
148+ return new (
149+ # NODE_CALL_UNIVARIATE
150+ univariate_operators,
151+ Dict {Symbol,Int} (
152+ op => i for (i, op) in enumerate (univariate_operators)
153+ ),
154+ length (univariate_operators),
155+ _UnivariateOperator[],
156+ # NODE_CALL
157+ multivariate_operators,
158+ Dict {Symbol,Int} (
159+ op => i for (i, op) in enumerate (multivariate_operators)
160+ ),
161+ length (multivariate_operators),
162+ MOI. Nonlinear. _MultivariateOperator[],
163+ # NODE_LOGIC
164+ logic_operators,
165+ Dict {Symbol,Int} (op => i for (i, op) in enumerate (logic_operators)),
166+ # NODE_COMPARISON
167+ comparison_operators,
168+ Dict {Symbol,Int} (
169+ op => i for (i, op) in enumerate (comparison_operators)
170+ ),
171+ )
172+ end
173+ end
174+
23175function eval_logic_function (
24176 :: OperatorRegistry ,
25177 op:: Symbol ,
@@ -34,6 +186,23 @@ function eval_logic_function(
34186 end
35187end
36188
189+ function _generate_eval_univariate ()
190+ exprs = map (Nonlinear. DEFAULT_UNIVARIATE_OPERATORS) do op
191+ return :(
192+ return (
193+ value_deriv_and_second ($ op, x)[1 ],
194+ value_deriv_and_second ($ op, x)[2 ],
195+ )
196+ )
197+ end
198+ return Nonlinear. _create_binary_switch (1 : length (exprs), exprs)
199+ end
200+
201+ @eval @inline function _eval_univariate (id, x:: T ) where {T}
202+ $ (_generate_eval_univariate ())
203+ return error (" Invalid id for univariate operator: $id " )
204+ end
205+
37206function eval_multivariate_function (
38207 registry:: OperatorRegistry ,
39208 op:: Symbol ,
@@ -165,17 +334,44 @@ function eval_multivariate_hessian(
165334 return true
166335end
167336
337+ function eval_univariate_function (operator:: _UnivariateOperator , x:: T ) where {T}
338+ ret = operator. f (x)
339+ check_return_type (T, ret)
340+ return ret:: T
341+ end
342+
343+ function eval_univariate_gradient (operator:: _UnivariateOperator , x:: T ) where {T}
344+ ret = operator. f′ (x)
345+ check_return_type (T, ret)
346+ return ret:: T
347+ end
348+
349+ function eval_univariate_hessian (operator:: _UnivariateOperator , x:: T ) where {T}
350+ ret = operator. f′′ (x)
351+ check_return_type (T, ret)
352+ return ret:: T
353+ end
354+
355+ function eval_univariate_function_and_gradient (
356+ operator:: _UnivariateOperator ,
357+ x:: T ,
358+ ) where {T}
359+ ret_f = eval_univariate_function (operator, x)
360+ ret_f′ = eval_univariate_gradient (operator, x)
361+ return ret_f, ret_f′
362+ end
363+
168364function eval_univariate_function_and_gradient (
169365 registry:: OperatorRegistry ,
170366 id:: Integer ,
171367 x:: T ,
172368) where {T}
173369 if id <= registry. univariate_user_operator_start
174- return Nonlinear . _eval_univariate (id, x):: Tuple{T,T}
370+ return _eval_univariate (id, x):: Tuple{T,T}
175371 end
176372 offset = id - registry. univariate_user_operator_start
177373 operator = registry. registered_univariate_operators[offset]
178- return Nonlinear . eval_univariate_function_and_gradient (operator, x)
374+ return eval_univariate_function_and_gradient (operator, x)
179375end
180376
181377function eval_multivariate_gradient (
0 commit comments