Skip to content

Commit 9dd38d5

Browse files
committed
Put dim back to 1 actually and use hasmethod
Dimension in this function represents the number of arguments the function has
1 parent 3f12765 commit 9dd38d5

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

src/operators.jl

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,18 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
2323
function _validate_register_assumptions(
2424
f::Function,
2525
name::Symbol,
26-
dimension::Integer,
26+
nb_args::Integer,
2727
)
2828
# Assumption 1: check that `f` can be called with `Float64` arguments.
29-
y = 0.0
30-
try
31-
if dimension == 0
32-
y = f(0.0)
33-
else
34-
y = f(zeros(dimension)...)
35-
end
36-
catch
37-
# We hit some other error, perhaps we called a function like log(-1).
38-
# Ignore for now, and hope that a useful error is shown to the user
39-
# during the solve.
29+
arg = nb_args == 1 ? 0.0 : zeros(nb_args)
30+
if hasmethod(f, Tuple{typeof(arg)})
31+
y = f(arg)
32+
else
33+
error(
34+
"Unable to register the function :$name.\n\n" *
35+
"The function must be able to be called with $nb_args Float64 " *
36+
"arguments, but no method was found for this.",
37+
)
4038
end
4139
if !(y isa Real)
4240
error(
@@ -46,10 +44,10 @@ function _validate_register_assumptions(
4644
end
4745
# Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
4846
try
49-
if dimension == 0
47+
if nb_args == 1
5048
ForwardDiff.derivative(f, 0.0)
5149
else
52-
ForwardDiff.gradient(x -> f(x...), zeros(dimension))
50+
ForwardDiff.gradient(x -> f(x...), zeros(nb_args))
5351
end
5452
catch err
5553
if err isa MethodError
@@ -104,14 +102,14 @@ struct _UnivariateOperator{F,F′,F′′}
104102
end
105103

106104
function _UnivariateOperator(op::Symbol, f::Function)
107-
_validate_register_assumptions(f, op, 0)
105+
_validate_register_assumptions(f, op, 1)
108106
f′ = _checked_derivative(f, op)
109107
return _UnivariateOperator(op, f, f′)
110108
end
111109

112110
function _UnivariateOperator(op::Symbol, f::Function, f′::Function)
113111
try
114-
_validate_register_assumptions(f′, op, 0)
112+
_validate_register_assumptions(f′, op, 1)
115113
f′′ = _checked_derivative(f′, op)
116114
return _UnivariateOperator(f, f′, f′′)
117115
catch

0 commit comments

Comments
 (0)