Skip to content

Commit 3f12765

Browse files
committed
Put dim to 0 when scalar in some function
See comment by BL on PR, arbitrary choice for now
1 parent 9b7c3aa commit 3f12765

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/operators.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function _validate_register_assumptions(
2828
# Assumption 1: check that `f` can be called with `Float64` arguments.
2929
y = 0.0
3030
try
31-
if dimension == 1
31+
if dimension == 0
3232
y = f(0.0)
3333
else
3434
y = f(zeros(dimension)...)
@@ -46,7 +46,7 @@ function _validate_register_assumptions(
4646
end
4747
# Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
4848
try
49-
if dimension == 1
49+
if dimension == 0
5050
ForwardDiff.derivative(f, 0.0)
5151
else
5252
ForwardDiff.gradient(x -> f(x...), zeros(dimension))
@@ -104,14 +104,14 @@ struct _UnivariateOperator{F,F′,F′′}
104104
end
105105

106106
function _UnivariateOperator(op::Symbol, f::Function)
107-
_validate_register_assumptions(f, op, 1)
107+
_validate_register_assumptions(f, op, 0)
108108
f′ = _checked_derivative(f, op)
109109
return _UnivariateOperator(op, f, f′)
110110
end
111111

112112
function _UnivariateOperator(op::Symbol, f::Function, f′::Function)
113113
try
114-
_validate_register_assumptions(f′, op, 1)
114+
_validate_register_assumptions(f′, op, 0)
115115
f′′ = _checked_derivative(f′, op)
116116
return _UnivariateOperator(f, f′, f′′)
117117
catch

0 commit comments

Comments
 (0)