Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
name = "ArrayDiff"
uuid = "c45fa1ca-6901-44ac-ae5b-5513a4852d50"
authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@gmail.com>"]
version = "0.1.0"
authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@gmail.com>"]

[deps]
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Calculus = "0.5.2"
DataStructures = "0.18, 0.19"
ForwardDiff = "1"
MathOptInterface = "1.40"
NaNMath = "1"
SparseArrays = "1.10"
SpecialFunctions = "2.6.1"
julia = "1.10"
3 changes: 2 additions & 1 deletion src/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ import NaNMath:
include("Coloring/Coloring.jl")
include("graph_tools.jl")
include("sizes.jl")
include("univariate_expressions.jl")
include("operators.jl")
include("types.jl")
include("utils.jl")

include("reverse_mode.jl")
include("forward_over_reverse.jl")
include("mathoptinterface_api.jl")
include("operators.jl")
include("model.jl")
include("parse.jl")
include("evaluator.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/mathoptinterface_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

_no_hessian(op::MOI.Nonlinear._UnivariateOperator) = op.f′′ === nothing
_no_hessian(op::_UnivariateOperator) = op.f′′ === nothing
_no_hessian(op::MOI.Nonlinear._MultivariateOperator) = op.∇²f === nothing

function MOI.features_available(d::NLPEvaluator)
Expand Down
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function register_operator(
elseif haskey(registry.multivariate_operator_to_id, op)
error("Operator $op is already registered.")
end
operator = Nonlinear._UnivariateOperator(op, f...)
operator = _UnivariateOperator(op, f...)
push!(registry.univariate_operators, op)
push!(registry.registered_univariate_operators, operator)
registry.univariate_operator_to_id[op] =
Expand Down
202 changes: 200 additions & 2 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,160 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
:row,
]

function _validate_register_assumptions(
f::Function,
name::Symbol,
dimension::Integer,
)
# Assumption 1: check that `f` can be called with `Float64` arguments.
y = 0.0
try
Copy link
Owner

@blegat blegat Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use hasmethod(f, (Array{dimension,Float64})) and hasmethod(f, Float64):

julia> hasmethod(size, (Array{3,Int},))
true

julia> hasmethod(log, (Array{3,Int},))
false

if dimension == 1
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For scalars, dimension should be 0 I think

Copy link
Collaborator Author

@SophieL1 SophieL1 Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's up to us to decide. I copy-pasted it from Nonlinear because it was needed in `_UnivariateOperator`, which I also copy-pasted. It's the only place it is used for now, and dimension 1 is hard-coded to mean "there is one argument", I guess.
I changed it to 0 since that may make more sense if we support vectors in the future.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, renaming nb_args sounds good :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sorry I've changed my mind; I put nb_args=1 indeed to make it clear :)

y = f(0.0)
else
y = f(zeros(dimension)...)
end
catch
# We hit some other error, perhaps we called a function like log(-1).
# Ignore for now, and hope that a useful error is shown to the user
# during the solve.
end
if !(y isa Real)
error(
"Expected return type of `Float64` from the user-defined " *
"function :$(name), but got `$(typeof(y))`.",
)
end
# Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
try
if dimension == 1
ForwardDiff.derivative(f, 0.0)
else
ForwardDiff.gradient(x -> f(x...), zeros(dimension))
end
catch err
if err isa MethodError
error(
"Unable to register the function :$name.\n\n" *
_FORWARD_DIFF_METHOD_ERROR_HELPER,
)
end
# We hit some other error, perhaps we called a function like log(-1).
# Ignore for now, and hope that a useful error is shown to the user
# during the solve.
end
return
end

"""
    _checked_derivative(f::F, op::Symbol) where {F}

Return a callable `x -> ForwardDiff.derivative(f, x)`. If ForwardDiff throws,
the error is handed to `_intercept_ForwardDiff_MethodError`, which presumably
rethrows a more informative message mentioning the operator `op` — TODO
confirm against its definition.
"""
function _checked_derivative(f::F, op::Symbol) where {F}
    function derivative_of_f(x)
        try
            return ForwardDiff.derivative(f, x)
        catch err
            _intercept_ForwardDiff_MethodError(err, op)
        end
    end
    return derivative_of_f
end

"""
check_return_type(::Type{T}, ret::S) where {T,S}

Overload this method for new types `S` to throw an informative error if a
user-defined function returns the type `S` instead of `T`.
"""
check_return_type(::Type{T}, ret::T) where {T} = nothing

function check_return_type(::Type{T}, ret) where {T}
return error(
"Expected return type of $T from a user-defined function, but got " *
"$(typeof(ret)).",
)
end

"""
    _UnivariateOperator{F,F′,F′′}

Storage for a univariate operator: the function `f`, its first derivative
`f′`, and its second derivative `f′′`, which may be `nothing` when no second
derivative is available (see `_no_hessian`, which tests `f′′ === nothing`).
"""
struct _UnivariateOperator{F,F′,F′′}
    f::F
    f′::F′
    f′′::F′′
    # Inner constructor: the type parameters are always derived from the
    # arguments, so each operator gets a fully concrete container type.
    function _UnivariateOperator(
        f::Function,
        f′::Function,
        f′′::Union{Nothing,Function} = nothing,
    )
        return new{typeof(f),typeof(f′),typeof(f′′)}(f, f′, f′′)
    end
end

# Only `f` supplied: validate it, build a ForwardDiff-based first derivative,
# and dispatch to the `(op, f, f′)` method below, which attempts to also
# derive a second derivative.
function _UnivariateOperator(op::Symbol, f::Function)
    _validate_register_assumptions(f, op, 1)
    return _UnivariateOperator(op, f, _checked_derivative(f, op))
end

# `f` and a user-supplied first derivative `f′`: try to differentiate `f′`
# with ForwardDiff to obtain a second derivative; on any failure fall back to
# storing the operator without one.
function _UnivariateOperator(op::Symbol, f::Function, f′::Function)
    f′′ = try
        _validate_register_assumptions(f′, op, 1)
        _checked_derivative(f′, op)
    catch
        nothing
    end
    return _UnivariateOperator(f, f′, f′′)
end

# All three functions supplied by the user: store them as-is, no validation.
function _UnivariateOperator(::Symbol, f::Function, f′::Function, f′′::Function)
    return _UnivariateOperator(f, f′, f′′)
end

"""
    OperatorRegistry

Registry of the operators known to the evaluator, grouped by node kind:
univariate calls, multivariate calls, logic operators, and comparisons. Each
group stores the operator symbols together with a symbol-to-id lookup; the
`*_user_operator_start` fields record how many built-in operators precede the
user-registered ones.
"""
struct OperatorRegistry
    # NODE_CALL_UNIVARIATE
    univariate_operators::Vector{Symbol}
    univariate_operator_to_id::Dict{Symbol,Int}
    univariate_user_operator_start::Int
    registered_univariate_operators::Vector{_UnivariateOperator}
    # NODE_CALL_MULTIVARIATE
    multivariate_operators::Vector{Symbol}
    multivariate_operator_to_id::Dict{Symbol,Int}
    multivariate_user_operator_start::Int
    registered_multivariate_operators::Vector{
        MOI.Nonlinear._MultivariateOperator,
    }
    # NODE_LOGIC
    logic_operators::Vector{Symbol}
    logic_operator_to_id::Dict{Symbol,Int}
    # NODE_COMPARISON
    comparison_operators::Vector{Symbol}
    comparison_operator_to_id::Dict{Symbol,Int}
    function OperatorRegistry()
        # Map each operator symbol to its 1-based position in `ops`.
        to_id(ops) = Dict{Symbol,Int}(s => i for (i, s) in enumerate(ops))
        univariate = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
        multivariate = copy(DEFAULT_MULTIVARIATE_OPERATORS)
        logic = [:&&, :||]
        comparison = [:<=, :(==), :>=, :<, :>]
        return new(
            univariate,
            to_id(univariate),
            length(univariate),
            _UnivariateOperator[],
            multivariate,
            to_id(multivariate),
            length(multivariate),
            MOI.Nonlinear._MultivariateOperator[],
            logic,
            to_id(logic),
            comparison,
            to_id(comparison),
        )
    end
end

function eval_logic_function(
::OperatorRegistry,
op::Symbol,
Expand All @@ -34,6 +188,23 @@ function eval_logic_function(
end
end

# Build the expression body used by `_eval_univariate`: a binary switch over
# the ids of the default univariate operators, where the branch for operator
# `op` returns the `(value, derivative)` pair of `op` at `x`.
function _generate_eval_univariate()
    exprs = map(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS) do op
        # Bind the result once instead of calling `value_deriv_and_second`
        # twice per evaluation (the original extracted `[1]` and `[2]` from
        # two separate calls).
        return :(
            let ret = value_deriv_and_second($op, x)
                return (ret[1], ret[2])
            end
        )
    end
    return Nonlinear._create_binary_switch(1:length(exprs), exprs)
end

# Splice the binary-switch expression produced by `_generate_eval_univariate()`
# into this method body via `@eval` (the switch is generated once, when the
# file is included). For a built-in operator `id` the matching branch returns
# the `(value, derivative)` tuple at `x`; any `id` that falls through the
# switch reaches the error below.
@eval @inline function _eval_univariate(id, x::T) where {T}
    $(_generate_eval_univariate())
    return error("Invalid id for univariate operator: $id")
end

function eval_multivariate_function(
registry::OperatorRegistry,
op::Symbol,
Expand Down Expand Up @@ -165,17 +336,44 @@ function eval_multivariate_hessian(
return true
end

"""
    eval_univariate_function(operator::_UnivariateOperator, x::T) where {T}

Evaluate `operator.f` at `x`, checking via `check_return_type` that the
result has type `T`.
"""
function eval_univariate_function(operator::_UnivariateOperator, x::T) where {T}
    value = operator.f(x)
    check_return_type(T, value)
    return value::T
end

"""
    eval_univariate_gradient(operator::_UnivariateOperator, x::T) where {T}

Evaluate the first derivative `operator.f′` at `x`, checking via
`check_return_type` that the result has type `T`.
"""
function eval_univariate_gradient(operator::_UnivariateOperator, x::T) where {T}
    derivative = operator.f′(x)
    check_return_type(T, derivative)
    return derivative::T
end

"""
    eval_univariate_hessian(operator::_UnivariateOperator, x::T) where {T}

Evaluate the second derivative `operator.f′′` at `x`, checking via
`check_return_type` that the result has type `T`. Assumes
`operator.f′′ !== nothing`; callers can test this with `_no_hessian`.
"""
function eval_univariate_hessian(operator::_UnivariateOperator, x::T) where {T}
    second_derivative = operator.f′′(x)
    check_return_type(T, second_derivative)
    return second_derivative::T
end

"""
    eval_univariate_function_and_gradient(
        operator::_UnivariateOperator,
        x::T,
    ) where {T}

Return the tuple `(f(x), f′(x))` for `operator`, with both results checked
for type `T`.
"""
function eval_univariate_function_and_gradient(
    operator::_UnivariateOperator,
    x::T,
) where {T}
    return (
        eval_univariate_function(operator, x),
        eval_univariate_gradient(operator, x),
    )
end

"""
    eval_univariate_function_and_gradient(
        registry::OperatorRegistry,
        id::Integer,
        x::T,
    ) where {T}

Return the `(value, derivative)` tuple at `x` for the univariate operator
with index `id` in `registry`. Built-in operators (ids up to
`univariate_user_operator_start`) go through the generated `_eval_univariate`
switch; later ids index into `registered_univariate_operators`.
"""
function eval_univariate_function_and_gradient(
    registry::OperatorRegistry,
    id::Integer,
    x::T,
) where {T}
    if id > registry.univariate_user_operator_start
        offset = id - registry.univariate_user_operator_start
        operator = registry.registered_univariate_operators[offset]
        return eval_univariate_function_and_gradient(operator, x)
    end
    return _eval_univariate(id, x)::Tuple{T,T}
end

function eval_multivariate_gradient(
Expand Down
51 changes: 0 additions & 51 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,57 +133,6 @@ struct _FunctionStorage
end
end

struct OperatorRegistry
# NODE_CALL_UNIVARIATE
univariate_operators::Vector{Symbol}
univariate_operator_to_id::Dict{Symbol,Int}
univariate_user_operator_start::Int
registered_univariate_operators::Vector{MOI.Nonlinear._UnivariateOperator}
# NODE_CALL_MULTIVARIATE
multivariate_operators::Vector{Symbol}
multivariate_operator_to_id::Dict{Symbol,Int}
multivariate_user_operator_start::Int
registered_multivariate_operators::Vector{
MOI.Nonlinear._MultivariateOperator,
}
# NODE_LOGIC
logic_operators::Vector{Symbol}
logic_operator_to_id::Dict{Symbol,Int}
# NODE_COMPARISON
comparison_operators::Vector{Symbol}
comparison_operator_to_id::Dict{Symbol,Int}
function OperatorRegistry()
univariate_operators = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS)
logic_operators = [:&&, :||]
comparison_operators = [:<=, :(==), :>=, :<, :>]
return new(
# NODE_CALL_UNIVARIATE
univariate_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(univariate_operators)
),
length(univariate_operators),
MOI.Nonlinear._UnivariateOperator[],
# NODE_CALL
multivariate_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(multivariate_operators)
),
length(multivariate_operators),
MOI.Nonlinear._MultivariateOperator[],
# NODE_LOGIC
logic_operators,
Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)),
# NODE_COMPARISON
comparison_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(comparison_operators)
),
)
end
end

"""
Model()

Expand Down
Loading
Loading