
Commit 0b17bc9

Add tests for NaN/Inf gradient handling
- Test with custom gradient function that injects NaN periodically
- Test with custom gradient function that injects Inf periodically
- Verify iterations complete and parameters remain finite
- Verify optimizer doesn't crash when encountering bad gradients
1 parent 7957bd6 commit 0b17bc9
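
The behavior these tests exercise, skipping the parameter update whenever the freshly computed gradient contains NaN or Inf, is not part of this diff itself. As a rough illustration only (not the actual OptimizationOptimisers internals; the function name and loop structure below are invented for this sketch), a guard of that kind inside an Optimisers.jl training loop could look like:

using Optimisers

# Illustrative sketch: an Adam loop that skips updates when the gradient is non-finite.
# `grad!` fills `G` in place; `θ` is the parameter vector being optimized.
function optimize_skipping_nonfinite(grad!, θ, p; maxiters = 20, η = 0.01)
    state = Optimisers.setup(Optimisers.Adam(η), θ)
    G = similar(θ)
    for _ in 1:maxiters
        grad!(G, θ, p)
        # Guard: if any entry is NaN or Inf, skip this update entirely
        all(isfinite, G) || continue
        state, θ = Optimisers.update!(state, θ, G)
    end
    return θ
end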

File tree

1 file changed: +61 -0 lines changed


lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 61 additions & 0 deletions
@@ -134,3 +134,64 @@ end
 
     @test res.objective < 1e-4
 end
+
+@testset "NaN/Inf gradient handling" begin
+    # Test that optimizer skips updates when gradients contain NaN or Inf
+    rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
+    x0 = zeros(2)
+    _p = [1.0, 100.0]
+
+    # Counter to track gradient evaluations
+    grad_counter = Ref(0)
+
+    # Custom gradient function that returns NaN on every 5th call
+    function custom_grad!(G, x, p)
+        grad_counter[] += 1
+        if grad_counter[] % 5 == 0
+            # Inject NaN into gradient
+            G .= NaN
+        else
+            # Normal gradient computation
+            G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2)
+            G[2] = 2.0 * p[2] * (x[2] - x[1]^2)
+        end
+        return nothing
+    end
+
+    optprob = OptimizationFunction(rosenbrock; grad = custom_grad!)
+    prob = OptimizationProblem(optprob, x0, _p)
+
+    # Should not throw error and should complete all iterations
+    sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false)
+
+    # Verify solution completed all iterations
+    @test sol.stats.iterations == 20
+
+    # Verify parameters are not NaN (would be NaN if updates were applied with NaN gradients)
+    @test all(!isnan, sol.u)
+    @test all(isfinite, sol.u)
+
+    # Test with Inf gradients
+    grad_counter_inf = Ref(0)
+    function custom_grad_inf!(G, x, p)
+        grad_counter_inf[] += 1
+        if grad_counter_inf[] % 7 == 0
+            # Inject Inf into gradient
+            G .= Inf
+        else
+            # Normal gradient computation
+            G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2)
+            G[2] = 2.0 * p[2] * (x[2] - x[1]^2)
+        end
+        return nothing
+    end
+
+    optprob_inf = OptimizationFunction(rosenbrock; grad = custom_grad_inf!)
+    prob_inf = OptimizationProblem(optprob_inf, x0, _p)
+
+    sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false)
+
+    @test sol_inf.stats.iterations == 20
+    @test all(!isnan, sol_inf.u)
+    @test all(isfinite, sol_inf.u)
+end
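
To try the added testset outside the package's full test suite, a small standalone script along the following lines should suffice; the exact set of `using` statements is an assumption about what the surrounding runtests.jl already loads, not part of this commit:

using Optimization, OptimizationOptimisers, Optimisers, Test

# Paste the "NaN/Inf gradient handling" testset from the diff above below this line,
# then run the file with `julia --project` from the lib/OptimizationOptimisers directory.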

0 commit comments
