-
Notifications
You must be signed in to change notification settings - Fork 0
Replace tuples by functions in univariate_expressions #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
e3b3ec8
244f2e1
c4491c7
9b7c3aa
3f12765
9dd38d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,20 +1,24 @@ | ||
| name = "ArrayDiff" | ||
| uuid = "c45fa1ca-6901-44ac-ae5b-5513a4852d50" | ||
| authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@gmail.com>"] | ||
| version = "0.1.0" | ||
| authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@gmail.com>"] | ||
|
|
||
| [deps] | ||
| Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" | ||
| DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" | ||
| ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
| MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" | ||
| NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" | ||
| SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
| OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" | ||
| SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
| SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" | ||
|
|
||
| [compat] | ||
| Calculus = "0.5.2" | ||
| DataStructures = "0.18, 0.19" | ||
| ForwardDiff = "1" | ||
| MathOptInterface = "1.40" | ||
| NaNMath = "1" | ||
| SparseArrays = "1.10" | ||
| SpecialFunctions = "2.6.1" | ||
| julia = "1.10" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,160 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [ | |
| :row, | ||
| ] | ||
|
|
||
| function _validate_register_assumptions( | ||
| f::Function, | ||
| name::Symbol, | ||
| dimension::Integer, | ||
| ) | ||
| # Assumption 1: check that `f` can be called with `Float64` arguments. | ||
| y = 0.0 | ||
| try | ||
| if dimension == 1 | ||
|
||
| y = f(0.0) | ||
| else | ||
| y = f(zeros(dimension)...) | ||
| end | ||
| catch | ||
| # We hit some other error, perhaps we called a function like log(-1). | ||
| # Ignore for now, and hope that a useful error is shown to the user | ||
| # during the solve. | ||
| end | ||
| if !(y isa Real) | ||
| error( | ||
| "Expected return type of `Float64` from the user-defined " * | ||
| "function :$(name), but got `$(typeof(y))`.", | ||
| ) | ||
| end | ||
| # Assumption 2: check that `f` can be differentiated using `ForwardDiff`. | ||
| try | ||
| if dimension == 1 | ||
| ForwardDiff.derivative(f, 0.0) | ||
| else | ||
| ForwardDiff.gradient(x -> f(x...), zeros(dimension)) | ||
| end | ||
| catch err | ||
| if err isa MethodError | ||
| error( | ||
| "Unable to register the function :$name.\n\n" * | ||
| _FORWARD_DIFF_METHOD_ERROR_HELPER, | ||
| ) | ||
| end | ||
| # We hit some other error, perhaps we called a function like log(-1). | ||
| # Ignore for now, and hope that a useful error is shown to the user | ||
| # during the solve. | ||
| end | ||
| return | ||
| end | ||
|
|
||
| function _checked_derivative(f::F, op::Symbol) where {F} | ||
| return function (x) | ||
| try | ||
| return ForwardDiff.derivative(f, x) | ||
| catch err | ||
| _intercept_ForwardDiff_MethodError(err, op) | ||
| end | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| check_return_type(::Type{T}, ret::S) where {T,S} | ||
|
|
||
| Overload this method for new types `S` to throw an informative error if a | ||
| user-defined function returns the type `S` instead of `T`. | ||
| """ | ||
| check_return_type(::Type{T}, ret::T) where {T} = nothing | ||
|
|
||
| function check_return_type(::Type{T}, ret) where {T} | ||
| return error( | ||
| "Expected return type of $T from a user-defined function, but got " * | ||
| "$(typeof(ret)).", | ||
| ) | ||
| end | ||
|
|
||
| struct _UnivariateOperator{F,F′,F′′} | ||
| f::F | ||
| f′::F′ | ||
| f′′::F′′ | ||
| function _UnivariateOperator( | ||
| f::Function, | ||
| f′::Function, | ||
| f′′::Union{Nothing,Function} = nothing, | ||
| ) | ||
| return new{typeof(f),typeof(f′),typeof(f′′)}(f, f′, f′′) | ||
| end | ||
| end | ||
|
|
||
| function _UnivariateOperator(op::Symbol, f::Function) | ||
| _validate_register_assumptions(f, op, 1) | ||
| f′ = _checked_derivative(f, op) | ||
| return _UnivariateOperator(op, f, f′) | ||
| end | ||
|
|
||
| function _UnivariateOperator(op::Symbol, f::Function, f′::Function) | ||
| try | ||
| _validate_register_assumptions(f′, op, 1) | ||
| f′′ = _checked_derivative(f′, op) | ||
| return _UnivariateOperator(f, f′, f′′) | ||
| catch | ||
| return _UnivariateOperator(f, f′, nothing) | ||
| end | ||
| end | ||
|
|
||
| function _UnivariateOperator(::Symbol, f::Function, f′::Function, f′′::Function) | ||
| return _UnivariateOperator(f, f′, f′′) | ||
| end | ||
|
|
||
| struct OperatorRegistry | ||
| # NODE_CALL_UNIVARIATE | ||
| univariate_operators::Vector{Symbol} | ||
| univariate_operator_to_id::Dict{Symbol,Int} | ||
| univariate_user_operator_start::Int | ||
| registered_univariate_operators::Vector{_UnivariateOperator} | ||
| # NODE_CALL_MULTIVARIATE | ||
| multivariate_operators::Vector{Symbol} | ||
| multivariate_operator_to_id::Dict{Symbol,Int} | ||
| multivariate_user_operator_start::Int | ||
| registered_multivariate_operators::Vector{ | ||
| MOI.Nonlinear._MultivariateOperator, | ||
| } | ||
| # NODE_LOGIC | ||
| logic_operators::Vector{Symbol} | ||
| logic_operator_to_id::Dict{Symbol,Int} | ||
| # NODE_COMPARISON | ||
| comparison_operators::Vector{Symbol} | ||
| comparison_operator_to_id::Dict{Symbol,Int} | ||
| function OperatorRegistry() | ||
| univariate_operators = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS) | ||
| multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS) | ||
| logic_operators = [:&&, :||] | ||
| comparison_operators = [:<=, :(==), :>=, :<, :>] | ||
| return new( | ||
| # NODE_CALL_UNIVARIATE | ||
| univariate_operators, | ||
| Dict{Symbol,Int}( | ||
| op => i for (i, op) in enumerate(univariate_operators) | ||
| ), | ||
| length(univariate_operators), | ||
| _UnivariateOperator[], | ||
| # NODE_CALL | ||
| multivariate_operators, | ||
| Dict{Symbol,Int}( | ||
| op => i for (i, op) in enumerate(multivariate_operators) | ||
| ), | ||
| length(multivariate_operators), | ||
| MOI.Nonlinear._MultivariateOperator[], | ||
| # NODE_LOGIC | ||
| logic_operators, | ||
| Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)), | ||
| # NODE_COMPARISON | ||
| comparison_operators, | ||
| Dict{Symbol,Int}( | ||
| op => i for (i, op) in enumerate(comparison_operators) | ||
| ), | ||
| ) | ||
| end | ||
| end | ||
|
|
||
| function eval_logic_function( | ||
| ::OperatorRegistry, | ||
| op::Symbol, | ||
|
|
@@ -34,6 +188,23 @@ function eval_logic_function( | |
| end | ||
| end | ||
|
|
||
| function _generate_eval_univariate() | ||
| exprs = map(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS) do op | ||
| return :( | ||
| return ( | ||
| value_deriv_and_second($op, x)[1], | ||
| value_deriv_and_second($op, x)[2], | ||
| ) | ||
| ) | ||
| end | ||
| return Nonlinear._create_binary_switch(1:length(exprs), exprs) | ||
| end | ||
|
|
||
| @eval @inline function _eval_univariate(id, x::T) where {T} | ||
| $(_generate_eval_univariate()) | ||
| return error("Invalid id for univariate operator: $id") | ||
| end | ||
|
|
||
| function eval_multivariate_function( | ||
| registry::OperatorRegistry, | ||
| op::Symbol, | ||
|
|
@@ -165,17 +336,44 @@ function eval_multivariate_hessian( | |
| return true | ||
| end | ||
|
|
||
| function eval_univariate_function(operator::_UnivariateOperator, x::T) where {T} | ||
| ret = operator.f(x) | ||
| check_return_type(T, ret) | ||
| return ret::T | ||
| end | ||
|
|
||
| function eval_univariate_gradient(operator::_UnivariateOperator, x::T) where {T} | ||
| ret = operator.f′(x) | ||
| check_return_type(T, ret) | ||
| return ret::T | ||
| end | ||
|
|
||
| function eval_univariate_hessian(operator::_UnivariateOperator, x::T) where {T} | ||
| ret = operator.f′′(x) | ||
| check_return_type(T, ret) | ||
| return ret::T | ||
| end | ||
|
|
||
| function eval_univariate_function_and_gradient( | ||
| operator::_UnivariateOperator, | ||
| x::T, | ||
| ) where {T} | ||
| ret_f = eval_univariate_function(operator, x) | ||
| ret_f′ = eval_univariate_gradient(operator, x) | ||
| return ret_f, ret_f′ | ||
| end | ||
|
|
||
| function eval_univariate_function_and_gradient( | ||
| registry::OperatorRegistry, | ||
| id::Integer, | ||
| x::T, | ||
| ) where {T} | ||
| if id <= registry.univariate_user_operator_start | ||
| return Nonlinear._eval_univariate(id, x)::Tuple{T,T} | ||
| return _eval_univariate(id, x)::Tuple{T,T} | ||
| end | ||
| offset = id - registry.univariate_user_operator_start | ||
| operator = registry.registered_univariate_operators[offset] | ||
| return Nonlinear.eval_univariate_function_and_gradient(operator, x) | ||
| return eval_univariate_function_and_gradient(operator, x) | ||
| end | ||
|
|
||
| function eval_multivariate_gradient( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can use
hasmethod(f, (Array{dimension,Float64}))andhasmethod(f, Float64):