Skip to content

Commit 6e1999d

Browse files
tests pass now pls
1 parent 5cf459a commit 6e1999d

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ LBFGSB = "0.4.1"
2929
LinearAlgebra = "1.10"
3030
Logging = "1.10"
3131
LoggingExtras = "0.4, 1"
32+
MLUtils = "0.4.4"
3233
OptimizationBase = "2.0.2"
3334
Printf = "1.10"
3435
ProgressLogging = "0.1"

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,18 @@ function SciMLBase.__solve(cache::OptimizationCache{
159159
return cache.sense === Optimization.MaxSense ? -__x : __x
160160
end
161161

162-
fg! = function (G, θ)
163-
if G !== nothing
164-
cache.f.grad(G, θ)
165-
if cache.sense === Optimization.MaxSense
166-
G .*= -one(eltype(G))
162+
if cache.f.fg === nothing
163+
fg! = function (G, θ)
164+
if G !== nothing
165+
cache.f.grad(G, θ)
166+
if cache.sense === Optimization.MaxSense
167+
G .*= -one(eltype(G))
168+
end
167169
end
170+
return _loss(θ)
168171
end
169-
return _loss(θ)
172+
else
173+
fg! = cache.f.fg
170174
end
171175

172176
if cache.opt isa Optim.KrylovTrustRegion

lib/OptimizationOptimisers/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1212
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1313

1414
[compat]
15+
MLUtils = "0.4.4"
1516
Optimisers = "0.2, 0.3"
1617
Optimization = "3.21"
1718
ProgressLogging = "0.1"

0 commit comments

Comments
 (0)