@@ -23,20 +23,18 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
2323function _validate_register_assumptions (
2424 f:: Function ,
2525 name:: Symbol ,
26- dimension :: Integer ,
26+ nb_args :: Integer ,
2727)
2828 # Assumption 1: check that `f` can be called with `Float64` arguments.
29- y = 0.0
30- try
31- if dimension == 0
32- y = f (0.0 )
33- else
34- y = f (zeros (dimension)... )
35- end
36- catch
37- # We hit some other error, perhaps we called a function like log(-1).
38- # Ignore for now, and hope that a useful error is shown to the user
39- # during the solve.
29+ arg = nb_args == 1 ? 0.0 : zeros (nb_args)
30+ if hasmethod (f, Tuple{typeof (arg)})
31+ y = f (arg)
32+ else
33+ error (
34+ " Unable to register the function :$name .\n\n " *
35+ " The function must be able to be called with $nb_args Float64 " *
36+ " arguments, but no method was found for this." ,
37+ )
4038 end
4139 if ! (y isa Real)
4240 error (
@@ -46,10 +44,10 @@ function _validate_register_assumptions(
4644 end
4745 # Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
4846 try
49- if dimension == 0
47+ if nb_args == 1
5048 ForwardDiff. derivative (f, 0.0 )
5149 else
52- ForwardDiff. gradient (x -> f (x... ), zeros (dimension ))
50+ ForwardDiff. gradient (x -> f (x... ), zeros (nb_args ))
5351 end
5452 catch err
5553 if err isa MethodError
@@ -104,14 +102,14 @@ struct _UnivariateOperator{F,F′,F′′}
104102end
105103
106104function _UnivariateOperator (op:: Symbol , f:: Function )
107- _validate_register_assumptions (f, op, 0 )
105+ _validate_register_assumptions (f, op, 1 )
108106 f′ = _checked_derivative (f, op)
109107 return _UnivariateOperator (op, f, f′)
110108end
111109
112110function _UnivariateOperator (op:: Symbol , f:: Function , f′:: Function )
113111 try
114- _validate_register_assumptions (f′, op, 0 )
112+ _validate_register_assumptions (f′, op, 1 )
115113 f′′ = _checked_derivative (f′, op)
116114 return _UnivariateOperator (f, f′, f′′)
117115 catch
0 commit comments