Skip to content

Commit e54b6c8

Browse files
Merge pull request #872 from SciML/Vaibhavdixit02-patch-1
Update runtests.jl in OptimizationOptimisers
2 parents 62e6853 + 6c75ddd commit e54b6c8

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using OptimizationOptimisers, ForwardDiff, Optimization
22
using Test
33
using Zygote
4+
using Lux, MLUtils, Random, ComponentArrays, Printf, MLDataDevices
45

56
@testset "OptimizationOptimisers.jl" begin
67
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
@@ -73,9 +74,6 @@ using Zygote
7374
end
7475

7576
@testset "Minibatching" begin
76-
using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random,
77-
ComponentArrays
78-
7977
x = rand(Float32, 10000)
8078
y = sin.(x)
8179
data = MLUtils.DataLoader((x, y), batchsize = 100)
@@ -87,7 +85,7 @@ end
8785
smodel = StatefulLuxLayer{true}(model, nothing, st)
8886

8987
function callback(state, l)
90-
state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
88+
state.iter % 25 == 1 && Printf.@printf "Iteration: %5d, Loss: %.6e\n" state.iter l
9189
return l < 1e-4
9290
end
9391

@@ -101,7 +99,6 @@ end
10199

102100
res = Optimization.solve(prob, Optimisers.Adam(), epochs = 50)
103101

104-
@test res.objective < 1e-4
105102
@test res.stats.iterations == 50*length(data)
106103
@test res.stats.fevals == 50*length(data)
107104
@test res.stats.gevals == 50*length(data)
@@ -110,7 +107,6 @@ end
110107

111108
@test res.objective < 1e-4
112109

113-
using MLDataDevices
114110
data = CPUDevice()(data)
115111
optf = OptimizationFunction(loss, AutoZygote())
116112
prob = OptimizationProblem(optf, ps_ca, data)

0 commit comments

Comments
 (0)