|
1 | 1 | module OptimizationOptimisers |
2 | 2 |
|
3 | | -using Reexport, Printf, ProgressLogging |
| 3 | +using Reexport, ProgressLogging, UUIDs |
4 | 4 | @reexport using Optimisers, OptimizationBase |
5 | 5 | using SciMLBase |
6 | 6 |
|
@@ -95,77 +95,73 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{ |
95 | 95 | gevals = 0 |
96 | 96 | t0 = time() |
97 | 97 | breakall = false |
98 | | - begin |
99 | | - for epoch in 1:epochs |
100 | | - if breakall |
101 | | - break |
| 98 | + progress_id = uuid4() |
| 99 | + for epoch in 1:epochs, d in data |
| 100 | + if cache.f.fg !== nothing && dataiterate |
| 101 | + x = cache.f.fg(G, θ, d) |
| 102 | + iterations += 1 |
| 103 | + fevals += 1 |
| 104 | + gevals += 1 |
| 105 | + elseif dataiterate |
| 106 | + cache.f.grad(G, θ, d) |
| 107 | + x = cache.f(θ, d) |
| 108 | + iterations += 1 |
| 109 | + fevals += 2 |
| 110 | + gevals += 1 |
| 111 | + elseif cache.f.fg !== nothing |
| 112 | + x = cache.f.fg(G, θ) |
| 113 | + iterations += 1 |
| 114 | + fevals += 1 |
| 115 | + gevals += 1 |
| 116 | + else |
| 117 | + cache.f.grad(G, θ) |
| 118 | + x = cache.f(θ) |
| 119 | + iterations += 1 |
| 120 | + fevals += 2 |
| 121 | + gevals += 1 |
| 122 | + end |
| 123 | + opt_state = OptimizationBase.OptimizationState( |
| 124 | + iter = iterations, |
| 125 | + u = θ, |
| 126 | + p = d, |
| 127 | + objective = x[1], |
| 128 | + grad = G, |
| 129 | + original = state) |
| 130 | + breakall = cache.callback(opt_state, x...) |
| 131 | + if !(breakall isa Bool) |
| 132 | + error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.") |
| 133 | + elseif breakall |
| 134 | + break |
| 135 | + end |
| 136 | + cache.progress && |
| 137 | + @info ProgressLogging.Progress(progress_id, iterations / maxiters; |
| 138 | + name = "loss: $(round(first(first(x)); digits=3))") |
| 139 | + |
| 140 | + if cache.solver_args.save_best |
| 141 | + if first(x)[1] < first(min_err)[1] #found a better solution |
| 142 | + min_opt = opt |
| 143 | + min_err = x |
| 144 | + min_θ = copy(θ) |
102 | 145 | end |
103 | | - for (i, d) in enumerate(data) |
104 | | - if cache.f.fg !== nothing && dataiterate |
105 | | - x = cache.f.fg(G, θ, d) |
106 | | - iterations += 1 |
107 | | - fevals += 1 |
108 | | - gevals += 1 |
109 | | - elseif dataiterate |
110 | | - cache.f.grad(G, θ, d) |
111 | | - x = cache.f(θ, d) |
112 | | - iterations += 1 |
113 | | - fevals += 2 |
114 | | - gevals += 1 |
115 | | - elseif cache.f.fg !== nothing |
116 | | - x = cache.f.fg(G, θ) |
117 | | - iterations += 1 |
118 | | - fevals += 1 |
119 | | - gevals += 1 |
120 | | - else |
121 | | - cache.f.grad(G, θ) |
122 | | - x = cache.f(θ) |
123 | | - iterations += 1 |
124 | | - fevals += 2 |
125 | | - gevals += 1 |
126 | | - end |
127 | | - opt_state = OptimizationBase.OptimizationState( |
128 | | - iter = i + (epoch - 1) * length(data), |
| 146 | + if iterations == length(data) * epochs #Last iter, revert to best. |
| 147 | + opt = min_opt |
| 148 | + x = min_err |
| 149 | + θ = min_θ |
| 150 | + cache.f.grad(G, θ, d) |
| 151 | + opt_state = OptimizationBase.OptimizationState(iter = iterations, |
129 | 152 | u = θ, |
130 | 153 | p = d, |
131 | 154 | objective = x[1], |
132 | 155 | grad = G, |
133 | 156 | original = state) |
134 | 157 | breakall = cache.callback(opt_state, x...) |
135 | | - if !(breakall isa Bool) |
136 | | - error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.") |
137 | | - elseif breakall |
138 | | - break |
139 | | - end |
140 | | - msg = @sprintf("loss: %.3g", first(x)[1]) |
141 | | - cache.progress && ProgressLogging.@logprogress msg iterations/maxiters |
142 | | - |
143 | | - if cache.solver_args.save_best |
144 | | - if first(x)[1] < first(min_err)[1] #found a better solution |
145 | | - min_opt = opt |
146 | | - min_err = x |
147 | | - min_θ = copy(θ) |
148 | | - end |
149 | | - if iterations == length(data) * epochs #Last iter, revert to best. |
150 | | - opt = min_opt |
151 | | - x = min_err |
152 | | - θ = min_θ |
153 | | - cache.f.grad(G, θ, d) |
154 | | - opt_state = OptimizationBase.OptimizationState(iter = iterations, |
155 | | - u = θ, |
156 | | - p = d, |
157 | | - objective = x[1], |
158 | | - grad = G, |
159 | | - original = state) |
160 | | - breakall = cache.callback(opt_state, x...) |
161 | | - break |
162 | | - end |
163 | | - end |
164 | | - state, θ = Optimisers.update(state, θ, G) |
| 158 | + break |
165 | 159 | end |
166 | 160 | end |
| 161 | + state, θ = Optimisers.update(state, θ, G) |
167 | 162 | end |
168 | 163 |
|
| 164 | + cache.progress && @info ProgressLogging.Progress(progress_id; done = true) |
169 | 165 | t1 = time() |
170 | 166 | stats = OptimizationBase.OptimizationStats(; iterations, |
171 | 167 | time = t1 - t0, fevals, gevals) |
|
0 commit comments