Skip to content

Commit 4fa329c

Browse files
Add regression test for issue #995
This adds a test to ensure that OptimizationState correctly includes the parameters `p` field when using OptimizationOptimisers. The issue was already fixed in commit 0d9efc2, but this test ensures the functionality is maintained. Fixes #995 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 0cb87b1 commit 4fa329c

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,29 @@ using Lux, MLUtils, Random, ComponentArrays, Printf, MLDataDevices
9090
end
9191

9292
@test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
93+
94+
@testset "Issue #995: Parameters in OptimizationState" begin
95+
# Regression test for https://github.com/SciML/Optimization.jl/issues/995
96+
# Ensure that OptimizationState contains the parameters p in callbacks
97+
rosenbrock_ = (x,p) -> (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
98+
x0 = zeros(2)
99+
p = [1.0, 100.0]
100+
101+
optfun = OptimizationFunction(rosenbrock_, Optimization.AutoForwardDiff())
102+
prob = OptimizationProblem(optfun, x0, p)
103+
104+
parameters_seen = []
105+
function cb(state, l)
106+
push!(parameters_seen, state.p)
107+
return false
108+
end
109+
110+
sol = solve(prob, Optimisers.Adam(0.1), maxiters=10, callback=cb)
111+
112+
# Check that all parameters seen in callbacks match the expected parameters
113+
@test all(p_seen -> p_seen == p, parameters_seen)
114+
@test length(parameters_seen) == 11 # One callback per iteration plus final callback
115+
end
93116
end
94117

95118
@testset "Minibatching" begin

0 commit comments

Comments
 (0)