
Commit 9ae7839

Address PR feedback
- Remove Functors dependency and use simple all(isfinite, G) check (see the REPL sketch below)
- Make warning conditional on cache.progress flag
- Rewrite tests to use functions that return NaN/Inf in certain regions instead of callback-based approach
1 parent 0295dcf commit 9ae7839
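The replacement check is a plain reduction over the gradient array. A quick REPL sketch of its behavior; the sample vectors are illustrative and not taken from the package:

```julia
julia> all(isfinite, [0.1, -0.2])   # a fully finite gradient passes
true

julia> all(isfinite, [0.1, NaN])    # any NaN entry fails the check
false

julia> all(isfinite, [0.1, -Inf])   # as does any infinite entry
false
```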

File tree: 3 files changed (+25, -48 lines)


lib/OptimizationOptimisers/Project.toml

Lines changed: 0 additions & 2 deletions
```diff
@@ -4,7 +4,6 @@ authors = ["Vaibhav Dixit <[email protected]> and contributors"]
 version = "0.3.13"
 
 [deps]
-Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -16,7 +15,6 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 OptimizationBase = {path = "../OptimizationBase"}
 
 [compat]
-Functors = "0.4, 0.5"
 Logging = "1.10"
 Optimisers = "0.2, 0.3, 0.4"
 OptimizationBase = "4"
```

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 2 additions & 15 deletions
```diff
@@ -3,24 +3,11 @@ module OptimizationOptimisers
 using Reexport, Logging
 @reexport using Optimisers, OptimizationBase
 using SciMLBase
-using Functors
 
 SciMLBase.has_init(opt::AbstractRule) = true
 SciMLBase.requiresgradient(opt::AbstractRule) = true
 SciMLBase.allowsfg(opt::AbstractRule) = true
 
-# Helper function to check if gradients contain NaN or Inf
-function has_nan_or_inf(x)
-    result = Ref(false)
-    Functors.fmap(x) do val
-        if val isa Number && (!isfinite(val))
-            result[] = true
-        end
-        return val
-    end
-    return result[]
-end
-
 function SciMLBase.__init(
         prob::SciMLBase.OptimizationProblem, opt::AbstractRule;
         callback = (args...) -> (false),
@@ -144,9 +131,9 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
                 end
             end
             # Skip update if gradient contains NaN or Inf values
-            if !has_nan_or_inf(G)
+            if all(isfinite, G)
                 state, θ = Optimisers.update(state, θ, G)
-            else
+            elseif cache.progress
                 @warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10
             end
         end
```
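A minimal sketch of the guard this hunk introduces, exercised directly against Optimisers.jl outside the solver. The helper guarded_step, its progress keyword, and the hand-written gradient vectors are illustrative only; in the patch the same logic lives inside SciMLBase.__solve and the warning is gated on cache.progress:

```julia
using Optimisers

# Hypothetical helper mirroring the patched loop body: apply the update only
# when every gradient entry is finite, otherwise optionally warn and skip.
function guarded_step(state, θ, G; progress = true)
    if all(isfinite, G)
        state, θ = Optimisers.update(state, θ, G)
    elseif progress
        @warn "Skipping parameter update due to NaN or Inf in gradients" maxlog = 10
    end
    return state, θ
end

θ = [1.0, 2.0]
state = Optimisers.setup(Optimisers.Adam(0.01), θ)

state, θ = guarded_step(state, θ, [0.1, -0.2])   # finite gradient: parameters move
state, θ = guarded_step(state, θ, [NaN, 0.3])    # skipped, warning emitted
state, θ = guarded_step(state, θ, [Inf, -0.1])   # skipped

@assert all(isfinite, θ)   # the non-finite gradients never reached the parameters
```

The elseif keeps the skip silent when progress reporting is off, which matches the second bullet of the commit message.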

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 23 additions & 31 deletions
```diff
@@ -137,54 +137,46 @@ end
 
 @testset "NaN/Inf gradient handling" begin
     # Test that optimizer skips updates when gradients contain NaN or Inf
-    rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
+    # Function that returns NaN when parameters are in certain regions
+    function weird_nan_function(x, p)
+        # Return NaN when x[1] is close to certain values to simulate numerical issues
+        if abs(x[1] - 0.3) < 0.05 || abs(x[1] + 0.3) < 0.05
+            return NaN
+        end
+        return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
+    end
+
     x0 = zeros(2)
     _p = [1.0, 100.0]
 
-    # Test with NaN gradients using Zygote
-    # We'll use a callback to inject NaN into some iterations
-    grad_counter = Ref(0)
-
-    # Create optimization problem with automatic differentiation
-    optprob = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote())
+    optprob = OptimizationFunction(weird_nan_function, OptimizationBase.AutoZygote())
     prob = OptimizationProblem(optprob, x0, _p)
 
-    # Use a callback that modifies the gradient to inject NaN periodically
-    function nan_callback(state, l)
-        grad_counter[] += 1
-        if grad_counter[] % 5 == 0
-            # Inject NaN into gradient on every 5th iteration
-            state.grad .= NaN
-        end
-        return false
-    end
-
     # Should not throw error and should complete all iterations
-    sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false, callback = nan_callback)
+    sol = solve(prob, Optimisers.Adam(0.01), maxiters = 50, progress = false)
 
     # Verify solution completed all iterations
-    @test sol.stats.iterations == 20
+    @test sol.stats.iterations == 50
 
     # Verify parameters are not NaN (would be NaN if updates were applied with NaN gradients)
     @test all(!isnan, sol.u)
     @test all(isfinite, sol.u)
 
-    # Test with Inf gradients
-    grad_counter_inf = Ref(0)
-    prob_inf = OptimizationProblem(optprob, x0, _p)
-
-    function inf_callback(state, l)
-        grad_counter_inf[] += 1
-        if grad_counter_inf[] % 7 == 0
-            # Inject Inf into gradient on every 7th iteration
-            state.grad .= Inf
+    # Function that returns Inf when parameters are in certain regions
+    function weird_inf_function(x, p)
+        # Return Inf when x[1] is close to certain values
+        if abs(x[1] - 0.2) < 0.05 || abs(x[1] + 0.2) < 0.05
+            return Inf
         end
-        return false
+        return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
     end
 
-    sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false, callback = inf_callback)
+    optprob_inf = OptimizationFunction(weird_inf_function, OptimizationBase.AutoZygote())
+    prob_inf = OptimizationProblem(optprob_inf, x0, _p)
+
+    sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 50, progress = false)
 
-    @test sol_inf.stats.iterations == 20
+    @test sol_inf.stats.iterations == 50
     @test all(!isnan, sol_inf.u)
     @test all(isfinite, sol_inf.u)
 end
```

0 commit comments
