@testset "NaN/Inf gradient handling" begin
    # Test that the optimizer skips parameter updates when gradients contain
    # NaN or Inf, instead of letting them corrupt the iterate.

    # Rosenbrock-style objective that blows up to NaN in narrow bands around
    # x[1] ≈ ±0.3 to simulate numerical issues.
    #
    # NOTE: the bad branch must multiply NaN by the input rather than
    # `return NaN` directly — a literal NaN is a constant, so Zygote would
    # differentiate it to a zero/`nothing` gradient and the NaN-*gradient*
    # guard under test would never fire. `NaN * (x[1] + x[2])` makes both
    # the loss value and the gradient NaN in the bad region.
    function weird_nan_function(x, p)
        if abs(x[1] - 0.3) < 0.05 || abs(x[1] + 0.3) < 0.05
            return NaN * (x[1] + x[2])
        end
        return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
    end

    x0 = zeros(2)
    _p = [1.0, 100.0]

    optprob = OptimizationFunction(weird_nan_function, OptimizationBase.AutoZygote())
    prob = OptimizationProblem(optprob, x0, _p)

    # Should not throw and should complete all iterations, even though the
    # Adam(0.01) trajectory from x0 = [0, 0] passes through the NaN band
    # near x[1] ≈ 0.3 within 50 steps.
    sol = solve(prob, Optimisers.Adam(0.01), maxiters = 50, progress = false)

    # Verify the solve ran to completion rather than aborting on NaN.
    @test sol.stats.iterations == 50

    # Parameters would be NaN if updates had been applied with NaN gradients.
    @test all(!isnan, sol.u)
    @test all(isfinite, sol.u)

    # Same idea with Inf: bands around x[1] ≈ ±0.2. Multiplying by
    # (x[1] + x[2] + 1.0) keeps the value +Inf in the bands reached from x0
    # while making the gradient Inf as well (d/dxᵢ = Inf · 1).
    function weird_inf_function(x, p)
        if abs(x[1] - 0.2) < 0.05 || abs(x[1] + 0.2) < 0.05
            return Inf * (x[1] + x[2] + 1.0)
        end
        return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
    end

    optprob_inf = OptimizationFunction(weird_inf_function, OptimizationBase.AutoZygote())
    prob_inf = OptimizationProblem(optprob_inf, x0, _p)

    sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 50, progress = false)

    @test sol_inf.stats.iterations == 50
    @test all(!isnan, sol_inf.u)
    @test all(isfinite, sol_inf.u)
end