Commit 7340165

Merge pull request #1081 from ChrisRackauckas-Claude/fix-nan-gradient-handling
Skip gradient updates when gradients contain NaN or Inf
2 parents: 846e6c1 + 16140e6

3 files changed: +72 −21 lines

lib/OptimizationOptimisers/Project.toml
20 additions & 19 deletions

@@ -4,33 +4,34 @@ authors = ["Vaibhav Dixit <[email protected]> and contributors"]
 version = "0.3.13"
 
 [deps]
-OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
-SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
-Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
+SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 
-[extras]
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
-MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
-Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
-Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+[sources]
+OptimizationBase = {path = "../OptimizationBase"}
 
 [compat]
-julia = "1.10"
-OptimizationBase = "4"
-SciMLBase = "2.122.1"
+Logging = "1.10"
 Optimisers = "0.2, 0.3, 0.4"
+OptimizationBase = "4"
 Reexport = "1.2"
-Logging = "1.10"
+SciMLBase = "2.122.1"
+julia = "1.10"
 
-[sources]
-OptimizationBase = {path = "../OptimizationBase"}
+[extras]
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
 test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"]
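For anyone reproducing the change, here is one way to exercise this subpackage's tests against the reorganized Project.toml. This is a sketch, assuming a checkout of the repository root and a Pkg version that understands the [sources] section; it is not a script from this PR:

using Pkg

# The [sources] entry resolves OptimizationBase from the sibling directory,
# so the subpackage must be activated from inside the repository checkout.
Pkg.activate("lib/OptimizationOptimisers")
Pkg.instantiate()
Pkg.test()  # pulls in the packages listed under [extras] via the `test` target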

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
8 additions & 2 deletions

@@ -67,6 +67,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
     breakall = false
     progress_id = :OptimizationOptimizersJL
     for epoch in 1:epochs, d in data
+
         if cache.f.fg !== nothing && dataiterate
             x = cache.f.fg(G, θ, d)
             iterations += 1
@@ -106,7 +107,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
         if cache.progress
             message = "Loss: $(round(first(first(x)); digits = 3))"
             @logmsg(LogLevel(-1), "Optimization", _id=progress_id,
-                message=message, progress=iterations / maxiters)
+                message=message, progress=iterations/maxiters)
         end
         if cache.solver_args.save_best
             if first(x)[1] < first(min_err)[1] #found a better solution
@@ -129,7 +130,12 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
                 break
             end
         end
-        state, θ = Optimisers.update(state, θ, G)
+        # Skip update if gradient contains NaN or Inf values
+        if all(isfinite, G)
+            state, θ = Optimisers.update(state, θ, G)
+        elseif cache.progress
+            @warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10
+        end
     end
     cache.progress && @logmsg(LogLevel(-1), "Optimization",
         _id=progress_id, message="Done", progress=1.0)
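The guard is simple enough to reproduce outside the package. Below is a minimal, self-contained sketch of the same skip-on-non-finite pattern using Optimisers.jl directly; the demo function, toy gradients, and iteration count are illustrative stand-ins, not code from this PR:

using Optimisers

function demo()
    θ = [0.0, 0.0]                                   # parameters being optimized
    state = Optimisers.setup(Optimisers.Adam(0.01), θ)
    for iter in 1:3
        # A poisoned gradient on step 2 stands in for an AD failure:
        G = iter == 2 ? [NaN, 1.0] : [0.1, -0.2]
        if all(isfinite, G)
            state, θ = Optimisers.update(state, θ, G)
        else
            @warn "Skipping parameter update due to non-finite gradient" iter maxlog=10
        end
    end
    return θ
end

demo()  # θ moves on steps 1 and 3; step 2 is skipped, so θ stays finite

Without the all(isfinite, G) check, a single NaN entry would propagate through Optimisers.update into both the parameters and the optimizer state, corrupting every subsequent step.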

lib/OptimizationOptimisers/test/runtests.jl
44 additions & 0 deletions

@@ -134,3 +134,47 @@ end
 
     @test res.objective < 1e-4
 end
+
+@testset "NaN/Inf gradient handling" begin
+    # Test that the optimizer skips updates when gradients contain NaN or Inf
+    # Function whose gradient can produce NaN via the sqrt term for x[1] <= 0
+    function weird_nan_function(x, p)
+        val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
+        # d/du sqrt(u) is Inf at u = 0, so the chain rule can yield NaN here
+        val += sqrt(max(x[1], 0.0)) * 0.01
+        return val
+    end
+
+    x0 = [-0.5, 0.1] # Start with negative x[1] so the sqrt term sits at its non-smooth point
+    _p = [1.0, 100.0]
+
+    optprob = OptimizationFunction(weird_nan_function, OptimizationBase.AutoZygote())
+    prob = OptimizationProblem(optprob, x0, _p)
+
+    # Should not throw an error and should complete all iterations
+    sol = solve(prob, Optimisers.Adam(0.01), maxiters = 50, progress = false)
+
+    # Verify the solve completed all iterations
+    @test sol.stats.iterations == 50
+
+    # Verify parameters are finite (NaN updates would have poisoned them)
+    @test all(!isnan, sol.u)
+    @test all(isfinite, sol.u)
+
+    # Function whose 0.01 / (abs(x[1] - 0.1) + 1e-8) term has an enormous
+    # gradient near the kink at x[1] = 0.1
+    function weird_inf_function(x, p)
+        val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
+        val += 0.01 / (abs(x[1] - 0.1) + 1e-8)
+        return val
+    end
+
+    optprob_inf = OptimizationFunction(weird_inf_function, OptimizationBase.AutoZygote())
+    prob_inf = OptimizationProblem(optprob_inf, x0, _p)
+
+    sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 50, progress = false)
+
+    @test sol_inf.stats.iterations == 50
+    @test all(!isnan, sol_inf.u)
+    @test all(isfinite, sol_inf.u)
+end
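As an aside on why the second test function stresses the guard: its reciprocal term has a near-singular gradient around x[1] = 0.1. A standalone check, illustrative and not part of the test suite, assuming Zygote (the backend AutoZygote() selects above):

using Zygote

# Isolate the troublesome term from weird_inf_function:
f(x) = 0.01 / (abs(x - 0.1) + 1e-8)

g = Zygote.gradient(f, 0.1 + 1e-9)[1]
@show g  # ≈ -8.3e13 just off the kink; near-singular terms like this are
         # how a run can end up feeding Inf or NaN into the update, which
         # the new guard now skips instead of writing into the parameters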
