@@ -141,28 +141,26 @@ end
141141 x0 = zeros (2 )
142142 _p = [1.0 , 100.0 ]
143143
144- # Counter to track gradient evaluations
144+ # Test with NaN gradients using Zygote
145+ # We'll use a callback to inject NaN into some iterations
145146 grad_counter = Ref (0 )
146147
147- # Custom gradient function that returns NaN on every 5th call
148- function custom_grad! (G, x, p)
148+ # Create optimization problem with automatic differentiation
149+ optprob = OptimizationFunction (rosenbrock, OptimizationBase. AutoZygote ())
150+ prob = OptimizationProblem (optprob, x0, _p)
151+
152+ # Use a callback that modifies the gradient to inject NaN periodically
153+ function nan_callback (state, l)
149154 grad_counter[] += 1
150155 if grad_counter[] % 5 == 0
151- # Inject NaN into gradient
152- G .= NaN
153- else
154- # Normal gradient computation
155- G[1 ] = - 2.0 * (p[1 ] - x[1 ]) - 4.0 * p[2 ] * x[1 ] * (x[2 ] - x[1 ]^ 2 )
156- G[2 ] = 2.0 * p[2 ] * (x[2 ] - x[1 ]^ 2 )
156+ # Inject NaN into gradient on every 5th iteration
157+ state. grad .= NaN
157158 end
158- return nothing
159+ return false
159160 end
160161
161- optprob = OptimizationFunction (rosenbrock; grad = custom_grad!)
162- prob = OptimizationProblem (optprob, x0, _p)
163-
164162 # Should not throw error and should complete all iterations
165- sol = solve (prob, Optimisers. Adam (0.01 ), maxiters = 20 , progress = false )
163+ sol = solve (prob, Optimisers. Adam (0.01 ), maxiters = 20 , progress = false , callback = nan_callback )
166164
167165 # Verify solution completed all iterations
168166 @test sol. stats. iterations == 20
@@ -173,23 +171,18 @@ end
173171
174172 # Test with Inf gradients
175173 grad_counter_inf = Ref (0 )
176- function custom_grad_inf! (G, x, p)
174+ prob_inf = OptimizationProblem (optprob, x0, _p)
175+
176+ function inf_callback (state, l)
177177 grad_counter_inf[] += 1
178178 if grad_counter_inf[] % 7 == 0
179- # Inject Inf into gradient
180- G .= Inf
181- else
182- # Normal gradient computation
183- G[1 ] = - 2.0 * (p[1 ] - x[1 ]) - 4.0 * p[2 ] * x[1 ] * (x[2 ] - x[1 ]^ 2 )
184- G[2 ] = 2.0 * p[2 ] * (x[2 ] - x[1 ]^ 2 )
179+ # Inject Inf into gradient on every 7th iteration
180+ state. grad .= Inf
185181 end
186- return nothing
182+ return false
187183 end
188184
189- optprob_inf = OptimizationFunction (rosenbrock; grad = custom_grad_inf!)
190- prob_inf = OptimizationProblem (optprob_inf, x0, _p)
191-
192- sol_inf = solve (prob_inf, Optimisers. Adam (0.01 ), maxiters = 20 , progress = false )
185+ sol_inf = solve (prob_inf, Optimisers. Adam (0.01 ), maxiters = 20 , progress = false , callback = inf_callback)
193186
194187 @test sol_inf. stats. iterations == 20
195188 @test all (! isnan, sol_inf. u)
0 commit comments