Skip to content

Commit 16140e6

Browse files
Update test functions to produce NaN/Inf gradients naturally
- Use sqrt and max to produce NaN when x goes negative - Use 1/x pattern to produce Inf gradients - Functions naturally produce problematic gradients during optimization
1 parent 9ae7839 commit 16140e6

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,16 +137,15 @@ end
137137

138138
@testset "NaN/Inf gradient handling" begin
139139
# Test that optimizer skips updates when gradients contain NaN or Inf
140-
# Function that returns NaN when parameters are in certain regions
140+
# Function whose sqrt(max(x[1], 0)) term has an unbounded derivative at 0, which can yield a non-finite gradient
141141
function weird_nan_function(x, p)
142-
# Return NaN when x[1] is close to certain values to simulate numerical issues
143-
if abs(x[1] - 0.3) < 0.05 || abs(x[1] + 0.3) < 0.05
144-
return NaN
145-
end
146-
return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
142+
val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
143+
# max clamps the sqrt argument at 0 (sqrt of a negative real would throw); the derivative of sqrt is unbounded there
144+
val += sqrt(max(x[1], 0.0)) * 0.01
145+
return val
147146
end
148147

149-
x0 = zeros(2)
148+
x0 = [-0.5, 0.1] # Start with negative x[1] so that max() clamps the sqrt argument to exactly 0
150149
_p = [1.0, 100.0]
151150

152151
optprob = OptimizationFunction(weird_nan_function, OptimizationBase.AutoZygote())
@@ -162,13 +161,12 @@ end
162161
@test all(!isnan, sol.u)
163162
@test all(isfinite, sol.u)
164163

165-
# Function that returns Inf when parameters are in certain regions
164+
# Function with a 1/|x[1] - 0.1| style term whose gradient becomes extremely large near x[1] = 0.1
166165
function weird_inf_function(x, p)
167-
# Return Inf when x[1] is close to certain values
168-
if abs(x[1] - 0.2) < 0.05 || abs(x[1] + 0.2) < 0.05
169-
return Inf
170-
end
171-
return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
166+
val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
167+
# 0.01 / (abs(x[1] - 0.1) + 1e-8) has a very large (but finite) gradient near x[1] = 0.1
168+
val += 0.01 / (abs(x[1] - 0.1) + 1e-8)
169+
return val
172170
end
173171

174172
optprob_inf = OptimizationFunction(weird_inf_function, OptimizationBase.AutoZygote())

0 commit comments

Comments
 (0)