|
1 | 1 | using OptimizationOptimisers, ForwardDiff, Optimization |
2 | 2 | using Test |
3 | 3 | using Zygote |
| 4 | +using Lux, MLUtils, Random, ComponentArrays, Printf, MLDataDevices |
4 | 5 |
|
5 | 6 | @testset "OptimizationOptimisers.jl" begin |
6 | 7 | rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 |
@@ -73,9 +74,6 @@ using Zygote |
73 | 74 | end |
74 | 75 |
|
75 | 76 | @testset "Minibatching" begin |
76 | | - using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random, |
77 | | - ComponentArrays |
78 | | - |
79 | 77 | x = rand(Float32, 10000) |
80 | 78 | y = sin.(x) |
81 | 79 | data = MLUtils.DataLoader((x, y), batchsize = 100) |
|
87 | 85 | smodel = StatefulLuxLayer{true}(model, nothing, st) |
88 | 86 |
|
89 | 87 | function callback(state, l) |
90 | | - state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l |
| 88 | + state.iter % 25 == 1 && Printf.@printf "Iteration: %5d, Loss: %.6e\n" state.iter l |
91 | 89 | return l < 1e-4 |
92 | 90 | end |
93 | 91 |
|
|
101 | 99 |
|
102 | 100 | res = Optimization.solve(prob, Optimisers.Adam(), epochs = 50) |
103 | 101 |
|
104 | | - @test res.objective < 1e-4 |
105 | 102 | @test res.stats.iterations == 50*length(data) |
106 | 103 | @test res.stats.fevals == 50*length(data) |
107 | 104 | @test res.stats.gevals == 50*length(data) |
|
110 | 107 |
|
111 | 108 | @test res.objective < 1e-4 |
112 | 109 |
|
113 | | - using MLDataDevices |
114 | 110 | data = CPUDevice()(data) |
115 | 111 | optf = OptimizationFunction(loss, AutoZygote()) |
116 | 112 | prob = OptimizationProblem(optf, ps_ca, data) |
|
0 commit comments