Skip to content

Commit 2a803ff

Browse files
tests pass now pls
1 parent 6e1999d commit 2a803ff

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ LinearAlgebra = "1.10"
3030
Logging = "1.10"
3131
LoggingExtras = "0.4, 1"
3232
MLUtils = "0.4.4"
33-
OptimizationBase = "2.0.2"
33+
OptimizationBase = "2.0.3"
3434
Printf = "1.10"
3535
ProgressLogging = "0.1"
3636
Reexport = "1.2"

src/sophia.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Optimization.LinearAlgebra
1+
using Optimization.LinearAlgebra, MLUtils
22

33
struct Sophia
44
η::Float64
@@ -80,14 +80,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
8080
for _ in 1:maxiters
8181
for (i, d) in enumerate(data)
8282
if cache.f.fg !== nothing && dataiterate
83-
x = cache.f.fg(G, θ, d)
83+
x = cache.f.fg(gₜ, θ, d)
8484
elseif dataiterate
85-
cache.f.grad(G, θ, d)
85+
cache.f.grad(gₜ, θ, d)
8686
x = cache.f(θ, d)
8787
elseif cache.f.fg !== nothing
88-
x = cache.f.fg(G, θ)
88+
x = cache.f.fg(gₜ, θ)
8989
else
90-
cache.f.grad(G, θ)
90+
cache.f.grad(gₜ, θ)
9191
x = cache.f(θ)
9292
end
9393
opt_state = Optimization.OptimizationState(; iter = i,

test/minibatch.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function dudt_(u, p, t)
1919
ann(u, p, st)[1] .* u
2020
end
2121

22-
function callback(state, l, pred) #callback function to observe training
22+
function callback(state, l) #callback function to observe training
2323
display(l)
2424
return false
2525
end

0 commit comments

Comments
 (0)