Commit cc69c5d

Fix test to use callback approach for NaN/Inf injection

- Use Zygote for gradient computation
- Inject NaN/Inf via callback that modifies state.grad
- This better simulates real-world scenarios where autodiff produces NaN
- Avoids issues with custom gradient function signatures
1 parent 0cf9eec commit cc69c5d

File tree

1 file changed (+19, -26 lines)

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 19 additions & 26 deletions
@@ -141,28 +141,26 @@ end
     x0 = zeros(2)
     _p = [1.0, 100.0]

-    # Counter to track gradient evaluations
+    # Test with NaN gradients using Zygote
+    # We'll use a callback to inject NaN into some iterations
     grad_counter = Ref(0)

-    # Custom gradient function that returns NaN on every 5th call
-    function custom_grad!(G, x, p)
+    # Create optimization problem with automatic differentiation
+    optprob = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote())
+    prob = OptimizationProblem(optprob, x0, _p)
+
+    # Use a callback that modifies the gradient to inject NaN periodically
+    function nan_callback(state, l)
         grad_counter[] += 1
         if grad_counter[] % 5 == 0
-            # Inject NaN into gradient
-            G .= NaN
-        else
-            # Normal gradient computation
-            G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2)
-            G[2] = 2.0 * p[2] * (x[2] - x[1]^2)
+            # Inject NaN into gradient on every 5th iteration
+            state.grad .= NaN
         end
-        return nothing
+        return false
     end

-    optprob = OptimizationFunction(rosenbrock; grad = custom_grad!)
-    prob = OptimizationProblem(optprob, x0, _p)
-
     # Should not throw error and should complete all iterations
-    sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false)
+    sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false, callback = nan_callback)

     # Verify solution completed all iterations
     @test sol.stats.iterations == 20
@@ -173,23 +171,18 @@ end

     # Test with Inf gradients
     grad_counter_inf = Ref(0)
-    function custom_grad_inf!(G, x, p)
+    prob_inf = OptimizationProblem(optprob, x0, _p)
+
+    function inf_callback(state, l)
         grad_counter_inf[] += 1
         if grad_counter_inf[] % 7 == 0
-            # Inject Inf into gradient
-            G .= Inf
-        else
-            # Normal gradient computation
-            G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2)
-            G[2] = 2.0 * p[2] * (x[2] - x[1]^2)
+            # Inject Inf into gradient on every 7th iteration
+            state.grad .= Inf
         end
-        return nothing
+        return false
     end

-    optprob_inf = OptimizationFunction(rosenbrock; grad = custom_grad_inf!)
-    prob_inf = OptimizationProblem(optprob_inf, x0, _p)
-
-    sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false)
+    sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false, callback = inf_callback)

     @test sol_inf.stats.iterations == 20
     @test all(!isnan, sol_inf.u)
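For reference, here is the new NaN-injection test distilled into a standalone sketch. It assumes the standard two-parameter Rosenbrock definition (consistent with the analytic gradient this commit removes) and that the test environment loads OptimizationBase, OptimizationOptimisers, Optimisers, Zygote, and Test; the actual imports in runtests.jl are not shown in this diff.

using OptimizationBase, OptimizationOptimisers, Optimisers, Zygote, Test

# Standard Rosenbrock objective; assumed here, matching the hand-written
# gradient removed by the commit.
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

x0 = zeros(2)
_p = [1.0, 100.0]

# Let Zygote supply gradients instead of a custom grad! function.
optprob = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote())
prob = OptimizationProblem(optprob, x0, _p)

grad_counter = Ref(0)
function nan_callback(state, l)
    grad_counter[] += 1
    if grad_counter[] % 5 == 0
        state.grad .= NaN  # overwrite the freshly computed gradient in place
    end
    return false  # false = keep optimizing; true would halt the solve
end

sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20,
    progress = false, callback = nan_callback)

@test sol.stats.iterations == 20  # all iterations run despite bad gradients

Mutating state.grad inside the callback corrupts the gradient after Zygote has produced it, which mirrors an autodiff backend emitting NaN at run time and sidesteps the custom-gradient-signature issues noted in the commit message; the Inf variant is identical apart from the modulus and the injected value.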
