Skip to content

Commit 5a2d692

Browse files
Add robust NaN/Inf checker using Functors.fmap
- Add has_nan_or_inf() helper function
- Uses Functors.fmap to recursively check all elements
- Handles arbitrary nested structures (arrays, ComponentArrays, etc.)
- Checks if any element is not finite (catches both NaN and Inf)
1 parent 1aa8780 commit 5a2d692

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,24 @@ module OptimizationOptimisers
33
using Reexport, Logging
44
@reexport using Optimisers, OptimizationBase
55
using SciMLBase
6+
using Functors
67

78
SciMLBase.has_init(opt::AbstractRule) = true
89
SciMLBase.requiresgradient(opt::AbstractRule) = true
910
SciMLBase.allowsfg(opt::AbstractRule) = true
1011

12+
# Helper function to check if gradients contain NaN or Inf values.
"""
    has_nan_or_inf(x) -> Bool

Return `true` if any numeric value contained in `x` is `NaN` or `±Inf`.

Scalars and numeric arrays are checked directly; any other structure is
walked with `Functors.fmap`, so arbitrarily nested containers (named
tuples of arrays, ComponentArrays wrapped in structs, etc.) are supported.
An empty array contains no non-finite values and yields `false`.
"""
function has_nan_or_inf(x)
    # Fast paths for the common call-site inputs (plain gradient arrays
    # and scalars); these avoid the Functors traversal entirely.
    x isa Number && return !isfinite(x)
    x isa AbstractArray{<:Number} && return !all(isfinite, x)

    # General case: walk the structure with Functors. `fmap` treats
    # numeric arrays as *leaves*, so the leaf callback must test whole
    # arrays as well as scalars — testing only `leaf isa Number` would
    # miss non-finite values stored inside an array leaf (the original
    # bug: `has_nan_or_inf(G)` was always `false` for array gradients).
    found = Ref(false)
    Functors.fmap(x) do leaf
        if leaf isa Number
            found[] |= !isfinite(leaf)
        elseif leaf isa AbstractArray{<:Number}
            found[] |= !all(isfinite, leaf)
        end
        return leaf
    end
    return found[]
end
1124
function SciMLBase.__init(
1225
prob::SciMLBase.OptimizationProblem, opt::AbstractRule;
1326
callback = (args...) -> (false),
@@ -130,8 +143,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
130143
end
131144
end
132145
# Skip update if gradient contains NaN or Inf values
133-
has_nan_or_inf = any(.!(isfinite.(G)))
134-
if !has_nan_or_inf
146+
if !has_nan_or_inf(G)
135147
state, θ = Optimisers.update(state, θ, G)
136148
else
137149
@warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10

0 commit comments

Comments
 (0)