Commit 5a76a9e

Add minibatching tests
1 parent 8d7cd3a · commit 5a76a9e

File tree: 4 files changed (+47 −6 lines)

lib/OptimizationOptimisers/Project.toml

Lines changed: 4 additions & 4 deletions

@@ -10,14 +10,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

-[extensions]
-OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
-OptimizationOptimisersMLUtilsExt = "MLUtils"
-
 [weakdeps]
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

+[extensions]
+OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
+OptimizationOptimisersMLUtilsExt = "MLUtils"
+
 [compat]
 MLDataDevices = "1.1"
 MLUtils = "0.4.4"
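
This Project.toml change only reorders sections so that [weakdeps] precedes [extensions]; the resulting TOML is functionally equivalent. As a reminder of how the extension mechanism behaves at load time, here is a minimal sketch (illustration only, assuming Julia 1.9+ with the packages installed; not part of this commit):

using OptimizationOptimisers   # the extensions are not loaded yet
Base.get_extension(OptimizationOptimisers, :OptimizationOptimisersMLUtilsExt)  # returns nothing

using MLUtils                  # loading the weak dependency triggers the extension
Base.get_extension(OptimizationOptimisers, :OptimizationOptimisersMLUtilsExt)  # returns the extension module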

lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl

Lines changed: 1 addition & 1 deletion

@@ -3,6 +3,6 @@ module OptimizationOptimisersMLDataDevicesExt
 using MLDataDevices
 using OptimizationOptimisers

-OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true
+OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = (@show "dkjht"; true)

 end
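
isa_dataiterator is the trait OptimizationOptimisers uses to decide whether the data attached to a problem should be iterated in minibatches; this extension method opts in MLDataDevices' DeviceIterator (the added @show "dkjht" is a debug print, so the method still returns true). A standalone sketch of the pattern, where the generic fallback and the MLUtils method are assumptions about the package's internals rather than lines from this commit:

using MLUtils, MLDataDevices

isa_dataiterator(data) = false                             # assumed generic fallback
isa_dataiterator(::MLUtils.DataLoader) = true              # assumed method from the MLUtils extension
isa_dataiterator(::MLDataDevices.DeviceIterator) = true    # the method touched in this diff

isa_dataiterator(MLUtils.DataLoader((rand(4), rand(4)), batchsize = 2))  # true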

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 1 addition & 1 deletion

@@ -117,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
 opt = min_opt
 x = min_err
 θ = min_θ
-cache.f.grad(G, θ, d...)
+cache.f.grad(G, θ, d)
 opt_state = Optimization.OptimizationState(iter = i,
     u = θ,
     objective = x[1],
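
With this change, the cached gradient is called with the current minibatch d as a single trailing argument instead of splatting it, matching a DataLoader that yields whole batches such as a tuple (xbatch, ybatch). A self-contained illustration of the calling convention, using a hypothetical in-place gradient rather than the package's internals:

# Each DataLoader batch is a tuple like (xbatch, ybatch).
d = (rand(5), rand(5))

grad!(G, θ, data) = (G .= 2 .* (θ .- data[1]) .+ data[2])   # hypothetical in-place gradient

G = zeros(5); θ = zeros(5)
grad!(G, θ, d)       # new convention: the batch is passed as one argument
# grad!(G, θ, d...)  # old convention would have required grad!(G, θ, xbatch, ybatch)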

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 41 additions & 0 deletions

@@ -68,3 +68,44 @@ using Zygote

     @test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
 end
+
+@testset "Minibatching" begin
+    using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Statistics, Plots,
+          Random, ComponentArrays
+
+    x = rand(10000)
+    y = sin.(x)
+    data = MLUtils.DataLoader((x, y), batchsize = 100)
+
+    # Define the neural network
+    model = Chain(Dense(1, 32, tanh), Dense(32, 1))
+    ps, st = Lux.setup(Random.default_rng(), model)
+    ps_ca = ComponentArray(ps)
+    smodel = StatefulLuxLayer{true}(model, nothing, st)
+
+    function callback(state, l)
+        state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
+        return l < 1e-4
+    end
+
+    function loss(ps, data)
+        ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
+        return sum(abs2, ypred .- data[2])
+    end
+
+    optf = OptimizationFunction(loss, AutoZygote())
+    prob = OptimizationProblem(optf, ps_ca, data)
+
+    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)
+
+    @test res.objective < 1e-4
+
+    using MLDataDevices
+    data = CPUDevice()(data)
+    optf = OptimizationFunction(loss, AutoZygote())
+    prob = OptimizationProblem(optf, ps_ca, data)
+
+    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)
+
+    @test res.objective < 1e-4
+end
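
One way to run the new test set locally (paths assume the Optimization.jl monorepo layout and are not prescribed by this commit):

using Pkg
Pkg.activate("lib/OptimizationOptimisers")   # activate the sublibrary's own project
Pkg.test()                                   # runs test/runtests.jl, including the Minibatching testset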
