Skip to content

Commit 9b7c3aa

Browse files
committed
Fix include problems and format
1 parent c4491c7 commit 9b7c3aa

File tree

4 files changed

+190
-170
lines changed

4 files changed

+190
-170
lines changed

src/ArrayDiff.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ import NaNMath:
3636
include("Coloring/Coloring.jl")
3737
include("graph_tools.jl")
3838
include("sizes.jl")
39+
include("univariate_expressions.jl")
40+
include("operators.jl")
3941
include("types.jl")
4042
include("utils.jl")
4143

4244
include("reverse_mode.jl")
4345
include("forward_over_reverse.jl")
4446
include("mathoptinterface_api.jl")
45-
include("operators.jl")
4647
include("model.jl")
4748
include("parse.jl")
4849
include("evaluator.jl")
49-
include("univariate_expressions.jl")
5050

5151
"""
5252
Mode() <: AbstractAutomaticDifferentiation

src/operators.jl

Lines changed: 160 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,160 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
2020
:row,
2121
]
2222

23+
function _validate_register_assumptions(
24+
f::Function,
25+
name::Symbol,
26+
dimension::Integer,
27+
)
28+
# Assumption 1: check that `f` can be called with `Float64` arguments.
29+
y = 0.0
30+
try
31+
if dimension == 1
32+
y = f(0.0)
33+
else
34+
y = f(zeros(dimension)...)
35+
end
36+
catch
37+
# We hit some other error, perhaps we called a function like log(-1).
38+
# Ignore for now, and hope that a useful error is shown to the user
39+
# during the solve.
40+
end
41+
if !(y isa Real)
42+
error(
43+
"Expected return type of `Float64` from the user-defined " *
44+
"function :$(name), but got `$(typeof(y))`.",
45+
)
46+
end
47+
# Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
48+
try
49+
if dimension == 1
50+
ForwardDiff.derivative(f, 0.0)
51+
else
52+
ForwardDiff.gradient(x -> f(x...), zeros(dimension))
53+
end
54+
catch err
55+
if err isa MethodError
56+
error(
57+
"Unable to register the function :$name.\n\n" *
58+
_FORWARD_DIFF_METHOD_ERROR_HELPER,
59+
)
60+
end
61+
# We hit some other error, perhaps we called a function like log(-1).
62+
# Ignore for now, and hope that a useful error is shown to the user
63+
# during the solve.
64+
end
65+
return
66+
end
67+
68+
function _checked_derivative(f::F, op::Symbol) where {F}
69+
return function (x)
70+
try
71+
return ForwardDiff.derivative(f, x)
72+
catch err
73+
_intercept_ForwardDiff_MethodError(err, op)
74+
end
75+
end
76+
end
77+
78+
"""
79+
check_return_type(::Type{T}, ret::S) where {T,S}
80+
81+
Overload this method for new types `S` to throw an informative error if a
82+
user-defined function returns the type `S` instead of `T`.
83+
"""
84+
check_return_type(::Type{T}, ret::T) where {T} = nothing
85+
86+
function check_return_type(::Type{T}, ret) where {T}
87+
return error(
88+
"Expected return type of $T from a user-defined function, but got " *
89+
"$(typeof(ret)).",
90+
)
91+
end
92+
93+
struct _UnivariateOperator{F,F′,F′′}
94+
f::F
95+
f′::F′
96+
f′′::F′′
97+
function _UnivariateOperator(
98+
f::Function,
99+
f′::Function,
100+
f′′::Union{Nothing,Function} = nothing,
101+
)
102+
return new{typeof(f),typeof(f′),typeof(f′′)}(f, f′, f′′)
103+
end
104+
end
105+
106+
function _UnivariateOperator(op::Symbol, f::Function)
107+
_validate_register_assumptions(f, op, 1)
108+
f′ = _checked_derivative(f, op)
109+
return _UnivariateOperator(op, f, f′)
110+
end
111+
112+
function _UnivariateOperator(op::Symbol, f::Function, f′::Function)
113+
try
114+
_validate_register_assumptions(f′, op, 1)
115+
f′′ = _checked_derivative(f′, op)
116+
return _UnivariateOperator(f, f′, f′′)
117+
catch
118+
return _UnivariateOperator(f, f′, nothing)
119+
end
120+
end
121+
122+
function _UnivariateOperator(::Symbol, f::Function, f′::Function, f′′::Function)
123+
return _UnivariateOperator(f, f′, f′′)
124+
end
125+
126+
struct OperatorRegistry
127+
# NODE_CALL_UNIVARIATE
128+
univariate_operators::Vector{Symbol}
129+
univariate_operator_to_id::Dict{Symbol,Int}
130+
univariate_user_operator_start::Int
131+
registered_univariate_operators::Vector{_UnivariateOperator}
132+
# NODE_CALL_MULTIVARIATE
133+
multivariate_operators::Vector{Symbol}
134+
multivariate_operator_to_id::Dict{Symbol,Int}
135+
multivariate_user_operator_start::Int
136+
registered_multivariate_operators::Vector{
137+
MOI.Nonlinear._MultivariateOperator,
138+
}
139+
# NODE_LOGIC
140+
logic_operators::Vector{Symbol}
141+
logic_operator_to_id::Dict{Symbol,Int}
142+
# NODE_COMPARISON
143+
comparison_operators::Vector{Symbol}
144+
comparison_operator_to_id::Dict{Symbol,Int}
145+
function OperatorRegistry()
146+
univariate_operators = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
147+
multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS)
148+
logic_operators = [:&&, :||]
149+
comparison_operators = [:<=, :(==), :>=, :<, :>]
150+
return new(
151+
# NODE_CALL_UNIVARIATE
152+
univariate_operators,
153+
Dict{Symbol,Int}(
154+
op => i for (i, op) in enumerate(univariate_operators)
155+
),
156+
length(univariate_operators),
157+
_UnivariateOperator[],
158+
# NODE_CALL
159+
multivariate_operators,
160+
Dict{Symbol,Int}(
161+
op => i for (i, op) in enumerate(multivariate_operators)
162+
),
163+
length(multivariate_operators),
164+
MOI.Nonlinear._MultivariateOperator[],
165+
# NODE_LOGIC
166+
logic_operators,
167+
Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)),
168+
# NODE_COMPARISON
169+
comparison_operators,
170+
Dict{Symbol,Int}(
171+
op => i for (i, op) in enumerate(comparison_operators)
172+
),
173+
)
174+
end
175+
end
176+
23177
function eval_logic_function(
24178
::OperatorRegistry,
25179
op::Symbol,
@@ -36,7 +190,12 @@ end
36190

37191
function _generate_eval_univariate()
38192
exprs = map(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS) do op
39-
return :(return (value_deriv_and_second($op, x)[1], value_deriv_and_second($op, x)[2]))
193+
return :(
194+
return (
195+
value_deriv_and_second($op, x)[1],
196+
value_deriv_and_second($op, x)[2],
197+
)
198+
)
40199
end
41200
return Nonlinear._create_binary_switch(1:length(exprs), exprs)
42201
end
@@ -177,109 +336,6 @@ function eval_multivariate_hessian(
177336
return true
178337
end
179338

180-
function _validate_register_assumptions(
181-
f::Function,
182-
name::Symbol,
183-
dimension::Integer,
184-
)
185-
# Assumption 1: check that `f` can be called with `Float64` arguments.
186-
y = 0.0
187-
try
188-
if dimension == 1
189-
y = f(0.0)
190-
else
191-
y = f(zeros(dimension)...)
192-
end
193-
catch
194-
# We hit some other error, perhaps we called a function like log(-1).
195-
# Ignore for now, and hope that a useful error is shown to the user
196-
# during the solve.
197-
end
198-
if !(y isa Real)
199-
error(
200-
"Expected return type of `Float64` from the user-defined " *
201-
"function :$(name), but got `$(typeof(y))`.",
202-
)
203-
end
204-
# Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
205-
try
206-
if dimension == 1
207-
ForwardDiff.derivative(f, 0.0)
208-
else
209-
ForwardDiff.gradient(x -> f(x...), zeros(dimension))
210-
end
211-
catch err
212-
if err isa MethodError
213-
error(
214-
"Unable to register the function :$name.\n\n" *
215-
_FORWARD_DIFF_METHOD_ERROR_HELPER,
216-
)
217-
end
218-
# We hit some other error, perhaps we called a function like log(-1).
219-
# Ignore for now, and hope that a useful error is shown to the user
220-
# during the solve.
221-
end
222-
return
223-
end
224-
225-
function _checked_derivative(f::F, op::Symbol) where {F}
226-
return function (x)
227-
try
228-
return ForwardDiff.derivative(f, x)
229-
catch err
230-
_intercept_ForwardDiff_MethodError(err, op)
231-
end
232-
end
233-
end
234-
235-
"""
236-
check_return_type(::Type{T}, ret::S) where {T,S}
237-
238-
Overload this method for new types `S` to throw an informative error if a
239-
user-defined function returns the type `S` instead of `T`.
240-
"""
241-
check_return_type(::Type{T}, ret::T) where {T} = nothing
242-
243-
function check_return_type(::Type{T}, ret) where {T}
244-
return error(
245-
"Expected return type of $T from a user-defined function, but got " *
246-
"$(typeof(ret)).",
247-
)
248-
end
249-
250-
struct _UnivariateOperator{F,F′,F′′}
251-
f::F
252-
f′::F′
253-
f′′::F′′
254-
function _UnivariateOperator(
255-
f::Function,
256-
f′::Function,
257-
f′′::Union{Nothing,Function} = nothing,
258-
)
259-
return new{typeof(f),typeof(f′),typeof(f′′)}(f, f′, f′′)
260-
end
261-
end
262-
263-
function _UnivariateOperator(op::Symbol, f::Function)
264-
_validate_register_assumptions(f, op, 1)
265-
f′ = _checked_derivative(f, op)
266-
return _UnivariateOperator(op, f, f′)
267-
end
268-
269-
function _UnivariateOperator(op::Symbol, f::Function, f′::Function)
270-
try
271-
_validate_register_assumptions(f′, op, 1)
272-
f′′ = _checked_derivative(f′, op)
273-
return _UnivariateOperator(f, f′, f′′)
274-
catch
275-
return _UnivariateOperator(f, f′, nothing)
276-
end
277-
end
278-
279-
function _UnivariateOperator(::Symbol, f::Function, f′::Function, f′′::Function)
280-
return _UnivariateOperator(f, f′, f′′)
281-
end
282-
283339
function eval_univariate_function(operator::_UnivariateOperator, x::T) where {T}
284340
ret = operator.f(x)
285341
check_return_type(T, ret)

src/types.jl

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -133,57 +133,6 @@ struct _FunctionStorage
133133
end
134134
end
135135

136-
struct OperatorRegistry
137-
# NODE_CALL_UNIVARIATE
138-
univariate_operators::Vector{Symbol}
139-
univariate_operator_to_id::Dict{Symbol,Int}
140-
univariate_user_operator_start::Int
141-
registered_univariate_operators::Vector{_UnivariateOperator}
142-
# NODE_CALL_MULTIVARIATE
143-
multivariate_operators::Vector{Symbol}
144-
multivariate_operator_to_id::Dict{Symbol,Int}
145-
multivariate_user_operator_start::Int
146-
registered_multivariate_operators::Vector{
147-
MOI.Nonlinear._MultivariateOperator,
148-
}
149-
# NODE_LOGIC
150-
logic_operators::Vector{Symbol}
151-
logic_operator_to_id::Dict{Symbol,Int}
152-
# NODE_COMPARISON
153-
comparison_operators::Vector{Symbol}
154-
comparison_operator_to_id::Dict{Symbol,Int}
155-
function OperatorRegistry()
156-
univariate_operators = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
157-
multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS)
158-
logic_operators = [:&&, :||]
159-
comparison_operators = [:<=, :(==), :>=, :<, :>]
160-
return new(
161-
# NODE_CALL_UNIVARIATE
162-
univariate_operators,
163-
Dict{Symbol,Int}(
164-
op => i for (i, op) in enumerate(univariate_operators)
165-
),
166-
length(univariate_operators),
167-
_UnivariateOperator[],
168-
# NODE_CALL
169-
multivariate_operators,
170-
Dict{Symbol,Int}(
171-
op => i for (i, op) in enumerate(multivariate_operators)
172-
),
173-
length(multivariate_operators),
174-
MOI.Nonlinear._MultivariateOperator[],
175-
# NODE_LOGIC
176-
logic_operators,
177-
Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)),
178-
# NODE_COMPARISON
179-
comparison_operators,
180-
Dict{Symbol,Int}(
181-
op => i for (i, op) in enumerate(comparison_operators)
182-
),
183-
)
184-
end
185-
end
186-
187136
"""
188137
Model()
189138

0 commit comments

Comments
 (0)