|
134 | 134 |
|
135 | 135 | @test res.objective < 1e-4 |
136 | 136 | end |
| 137 | + |
| 138 | +@testset "NaN/Inf gradient handling" begin |
| 139 | + # Test that optimizer skips updates when gradients contain NaN or Inf |
| 140 | + rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 |
| 141 | + x0 = zeros(2) |
| 142 | + _p = [1.0, 100.0] |
| 143 | + |
| 144 | + # Counter to track gradient evaluations |
| 145 | + grad_counter = Ref(0) |
| 146 | + |
| 147 | + # Custom gradient function that returns NaN on every 5th call |
| 148 | + function custom_grad!(G, x, p) |
| 149 | + grad_counter[] += 1 |
| 150 | + if grad_counter[] % 5 == 0 |
| 151 | + # Inject NaN into gradient |
| 152 | + G .= NaN |
| 153 | + else |
| 154 | + # Normal gradient computation |
| 155 | + G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2) |
| 156 | + G[2] = 2.0 * p[2] * (x[2] - x[1]^2) |
| 157 | + end |
| 158 | + return nothing |
| 159 | + end |
| 160 | + |
| 161 | + optprob = OptimizationFunction(rosenbrock; grad = custom_grad!) |
| 162 | + prob = OptimizationProblem(optprob, x0, _p) |
| 163 | + |
| 164 | + # Should not throw error and should complete all iterations |
| 165 | + sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false) |
| 166 | + |
| 167 | + # Verify solution completed all iterations |
| 168 | + @test sol.stats.iterations == 20 |
| 169 | + |
| 170 | + # Verify parameters are not NaN (would be NaN if updates were applied with NaN gradients) |
| 171 | + @test all(!isnan, sol.u) |
| 172 | + @test all(isfinite, sol.u) |
| 173 | + |
| 174 | + # Test with Inf gradients |
| 175 | + grad_counter_inf = Ref(0) |
| 176 | + function custom_grad_inf!(G, x, p) |
| 177 | + grad_counter_inf[] += 1 |
| 178 | + if grad_counter_inf[] % 7 == 0 |
| 179 | + # Inject Inf into gradient |
| 180 | + G .= Inf |
| 181 | + else |
| 182 | + # Normal gradient computation |
| 183 | + G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2) |
| 184 | + G[2] = 2.0 * p[2] * (x[2] - x[1]^2) |
| 185 | + end |
| 186 | + return nothing |
| 187 | + end |
| 188 | + |
| 189 | + optprob_inf = OptimizationFunction(rosenbrock; grad = custom_grad_inf!) |
| 190 | + prob_inf = OptimizationProblem(optprob_inf, x0, _p) |
| 191 | + |
| 192 | + sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false) |
| 193 | + |
| 194 | + @test sol_inf.stats.iterations == 20 |
| 195 | + @test all(!isnan, sol_inf.u) |
| 196 | + @test all(isfinite, sol_inf.u) |
| 197 | +end |
0 commit comments