Skip to content

Commit 68dbc68

Browse files
authored
Merge pull request #26 from blegat/sl/ChainRules
Change univariate_expressions (tuples instead of functions)
2 parents 6caa6c9 + 9dd38d5 commit 68dbc68

File tree

10 files changed

+386
-59
lines changed

10 files changed

+386
-59
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
name = "ArrayDiff"
22
uuid = "c45fa1ca-6901-44ac-ae5b-5513a4852d50"
3-
authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@gmail.com>"]
43
version = "0.1.0"
4+
authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@gmail.com>"]
55

66
[deps]
7+
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
78
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
89
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
910
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1011
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
11-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1212
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
13+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
14+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1315

1416
[compat]
17+
Calculus = "0.5.2"
1518
DataStructures = "0.18, 0.19"
1619
ForwardDiff = "1"
1720
MathOptInterface = "1.40"
1821
NaNMath = "1"
1922
SparseArrays = "1.10"
23+
SpecialFunctions = "2.6.1"
2024
julia = "1.10"

src/ArrayDiff.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@ import NaNMath:
3636
include("Coloring/Coloring.jl")
3737
include("graph_tools.jl")
3838
include("sizes.jl")
39+
include("univariate_expressions.jl")
40+
include("operators.jl")
3941
include("types.jl")
4042
include("utils.jl")
4143

4244
include("reverse_mode.jl")
4345
include("forward_over_reverse.jl")
4446
include("mathoptinterface_api.jl")
45-
include("operators.jl")
4647
include("model.jl")
4748
include("parse.jl")
4849
include("evaluator.jl")

src/mathoptinterface_api.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Use of this source code is governed by an MIT-style license that can be found
55
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
66

7-
_no_hessian(op::MOI.Nonlinear._UnivariateOperator) = op.f′′ === nothing
7+
_no_hessian(op::_UnivariateOperator) = op.f′′ === nothing
88
_no_hessian(op::MOI.Nonlinear._MultivariateOperator) = op.∇²f === nothing
99

1010
function MOI.features_available(d::NLPEvaluator)

src/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function register_operator(
5353
elseif haskey(registry.multivariate_operator_to_id, op)
5454
error("Operator $op is already registered.")
5555
end
56-
operator = Nonlinear._UnivariateOperator(op, f...)
56+
operator = _UnivariateOperator(op, f...)
5757
push!(registry.univariate_operators, op)
5858
push!(registry.registered_univariate_operators, operator)
5959
registry.univariate_operator_to_id[op] =

src/operators.jl

Lines changed: 198 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,158 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
2020
:row,
2121
]
2222

23+
function _validate_register_assumptions(
24+
f::Function,
25+
name::Symbol,
26+
nb_args::Integer,
27+
)
28+
# Assumption 1: check that `f` can be called with `Float64` arguments.
29+
arg = nb_args == 1 ? 0.0 : zeros(nb_args)
30+
if hasmethod(f, Tuple{typeof(arg)})
31+
y = f(arg)
32+
else
33+
error(
34+
"Unable to register the function :$name.\n\n" *
35+
"The function must be able to be called with $nb_args Float64 " *
36+
"arguments, but no method was found for this.",
37+
)
38+
end
39+
if !(y isa Real)
40+
error(
41+
"Expected return type of `Float64` from the user-defined " *
42+
"function :$(name), but got `$(typeof(y))`.",
43+
)
44+
end
45+
# Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
46+
try
47+
if nb_args == 1
48+
ForwardDiff.derivative(f, 0.0)
49+
else
50+
ForwardDiff.gradient(x -> f(x...), zeros(nb_args))
51+
end
52+
catch err
53+
if err isa MethodError
54+
error(
55+
"Unable to register the function :$name.\n\n" *
56+
_FORWARD_DIFF_METHOD_ERROR_HELPER,
57+
)
58+
end
59+
# We hit some other error, perhaps we called a function like log(-1).
60+
# Ignore for now, and hope that a useful error is shown to the user
61+
# during the solve.
62+
end
63+
return
64+
end
65+
66+
function _checked_derivative(f::F, op::Symbol) where {F}
67+
return function (x)
68+
try
69+
return ForwardDiff.derivative(f, x)
70+
catch err
71+
_intercept_ForwardDiff_MethodError(err, op)
72+
end
73+
end
74+
end
75+
76+
"""
77+
check_return_type(::Type{T}, ret::S) where {T,S}
78+
79+
Overload this method for new types `S` to throw an informative error if a
80+
user-defined function returns the type `S` instead of `T`.
81+
"""
82+
check_return_type(::Type{T}, ret::T) where {T} = nothing
83+
84+
function check_return_type(::Type{T}, ret) where {T}
85+
return error(
86+
"Expected return type of $T from a user-defined function, but got " *
87+
"$(typeof(ret)).",
88+
)
89+
end
90+
91+
struct _UnivariateOperator{F,F′,F′′}
92+
f::F
93+
f′::F′
94+
f′′::F′′
95+
function _UnivariateOperator(
96+
f::Function,
97+
f′::Function,
98+
f′′::Union{Nothing,Function} = nothing,
99+
)
100+
return new{typeof(f),typeof(f′),typeof(f′′)}(f, f′, f′′)
101+
end
102+
end
103+
104+
function _UnivariateOperator(op::Symbol, f::Function)
105+
_validate_register_assumptions(f, op, 1)
106+
f′ = _checked_derivative(f, op)
107+
return _UnivariateOperator(op, f, f′)
108+
end
109+
110+
function _UnivariateOperator(op::Symbol, f::Function, f′::Function)
111+
try
112+
_validate_register_assumptions(f′, op, 1)
113+
f′′ = _checked_derivative(f′, op)
114+
return _UnivariateOperator(f, f′, f′′)
115+
catch
116+
return _UnivariateOperator(f, f′, nothing)
117+
end
118+
end
119+
120+
function _UnivariateOperator(::Symbol, f::Function, f′::Function, f′′::Function)
121+
return _UnivariateOperator(f, f′, f′′)
122+
end
123+
124+
struct OperatorRegistry
125+
# NODE_CALL_UNIVARIATE
126+
univariate_operators::Vector{Symbol}
127+
univariate_operator_to_id::Dict{Symbol,Int}
128+
univariate_user_operator_start::Int
129+
registered_univariate_operators::Vector{_UnivariateOperator}
130+
# NODE_CALL_MULTIVARIATE
131+
multivariate_operators::Vector{Symbol}
132+
multivariate_operator_to_id::Dict{Symbol,Int}
133+
multivariate_user_operator_start::Int
134+
registered_multivariate_operators::Vector{
135+
MOI.Nonlinear._MultivariateOperator,
136+
}
137+
# NODE_LOGIC
138+
logic_operators::Vector{Symbol}
139+
logic_operator_to_id::Dict{Symbol,Int}
140+
# NODE_COMPARISON
141+
comparison_operators::Vector{Symbol}
142+
comparison_operator_to_id::Dict{Symbol,Int}
143+
function OperatorRegistry()
144+
univariate_operators = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
145+
multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS)
146+
logic_operators = [:&&, :||]
147+
comparison_operators = [:<=, :(==), :>=, :<, :>]
148+
return new(
149+
# NODE_CALL_UNIVARIATE
150+
univariate_operators,
151+
Dict{Symbol,Int}(
152+
op => i for (i, op) in enumerate(univariate_operators)
153+
),
154+
length(univariate_operators),
155+
_UnivariateOperator[],
156+
# NODE_CALL
157+
multivariate_operators,
158+
Dict{Symbol,Int}(
159+
op => i for (i, op) in enumerate(multivariate_operators)
160+
),
161+
length(multivariate_operators),
162+
MOI.Nonlinear._MultivariateOperator[],
163+
# NODE_LOGIC
164+
logic_operators,
165+
Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)),
166+
# NODE_COMPARISON
167+
comparison_operators,
168+
Dict{Symbol,Int}(
169+
op => i for (i, op) in enumerate(comparison_operators)
170+
),
171+
)
172+
end
173+
end
174+
23175
function eval_logic_function(
24176
::OperatorRegistry,
25177
op::Symbol,
@@ -34,6 +186,23 @@ function eval_logic_function(
34186
end
35187
end
36188

189+
function _generate_eval_univariate()
190+
exprs = map(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS) do op
191+
return :(
192+
return (
193+
value_deriv_and_second($op, x)[1],
194+
value_deriv_and_second($op, x)[2],
195+
)
196+
)
197+
end
198+
return Nonlinear._create_binary_switch(1:length(exprs), exprs)
199+
end
200+
201+
@eval @inline function _eval_univariate(id, x::T) where {T}
202+
$(_generate_eval_univariate())
203+
return error("Invalid id for univariate operator: $id")
204+
end
205+
37206
function eval_multivariate_function(
38207
registry::OperatorRegistry,
39208
op::Symbol,
@@ -165,17 +334,44 @@ function eval_multivariate_hessian(
165334
return true
166335
end
167336

337+
function eval_univariate_function(operator::_UnivariateOperator, x::T) where {T}
338+
ret = operator.f(x)
339+
check_return_type(T, ret)
340+
return ret::T
341+
end
342+
343+
function eval_univariate_gradient(operator::_UnivariateOperator, x::T) where {T}
344+
ret = operator.f′(x)
345+
check_return_type(T, ret)
346+
return ret::T
347+
end
348+
349+
function eval_univariate_hessian(operator::_UnivariateOperator, x::T) where {T}
350+
ret = operator.f′′(x)
351+
check_return_type(T, ret)
352+
return ret::T
353+
end
354+
355+
function eval_univariate_function_and_gradient(
356+
operator::_UnivariateOperator,
357+
x::T,
358+
) where {T}
359+
ret_f = eval_univariate_function(operator, x)
360+
ret_f′ = eval_univariate_gradient(operator, x)
361+
return ret_f, ret_f′
362+
end
363+
168364
function eval_univariate_function_and_gradient(
169365
registry::OperatorRegistry,
170366
id::Integer,
171367
x::T,
172368
) where {T}
173369
if id <= registry.univariate_user_operator_start
174-
return Nonlinear._eval_univariate(id, x)::Tuple{T,T}
370+
return _eval_univariate(id, x)::Tuple{T,T}
175371
end
176372
offset = id - registry.univariate_user_operator_start
177373
operator = registry.registered_univariate_operators[offset]
178-
return Nonlinear.eval_univariate_function_and_gradient(operator, x)
374+
return eval_univariate_function_and_gradient(operator, x)
179375
end
180376

181377
function eval_multivariate_gradient(

src/types.jl

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -133,57 +133,6 @@ struct _FunctionStorage
133133
end
134134
end
135135

136-
struct OperatorRegistry
137-
# NODE_CALL_UNIVARIATE
138-
univariate_operators::Vector{Symbol}
139-
univariate_operator_to_id::Dict{Symbol,Int}
140-
univariate_user_operator_start::Int
141-
registered_univariate_operators::Vector{MOI.Nonlinear._UnivariateOperator}
142-
# NODE_CALL_MULTIVARIATE
143-
multivariate_operators::Vector{Symbol}
144-
multivariate_operator_to_id::Dict{Symbol,Int}
145-
multivariate_user_operator_start::Int
146-
registered_multivariate_operators::Vector{
147-
MOI.Nonlinear._MultivariateOperator,
148-
}
149-
# NODE_LOGIC
150-
logic_operators::Vector{Symbol}
151-
logic_operator_to_id::Dict{Symbol,Int}
152-
# NODE_COMPARISON
153-
comparison_operators::Vector{Symbol}
154-
comparison_operator_to_id::Dict{Symbol,Int}
155-
function OperatorRegistry()
156-
univariate_operators = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
157-
multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS)
158-
logic_operators = [:&&, :||]
159-
comparison_operators = [:<=, :(==), :>=, :<, :>]
160-
return new(
161-
# NODE_CALL_UNIVARIATE
162-
univariate_operators,
163-
Dict{Symbol,Int}(
164-
op => i for (i, op) in enumerate(univariate_operators)
165-
),
166-
length(univariate_operators),
167-
MOI.Nonlinear._UnivariateOperator[],
168-
# NODE_CALL
169-
multivariate_operators,
170-
Dict{Symbol,Int}(
171-
op => i for (i, op) in enumerate(multivariate_operators)
172-
),
173-
length(multivariate_operators),
174-
MOI.Nonlinear._MultivariateOperator[],
175-
# NODE_LOGIC
176-
logic_operators,
177-
Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)),
178-
# NODE_COMPARISON
179-
comparison_operators,
180-
Dict{Symbol,Int}(
181-
op => i for (i, op) in enumerate(comparison_operators)
182-
),
183-
)
184-
end
185-
end
186-
187136
"""
188137
Model()
189138

0 commit comments

Comments
 (0)