@@ -42,27 +42,27 @@ function SciMLBase.__solve(cache::OptimizationCache{
4242 P,
4343 C
4444}
45- maxiters = if cache. solver_args. epochs === nothing
45+ if OptimizationBase. isa_dataiterator (cache. p)
46+ data = cache. p
47+ dataiterate = true
48+ else
49+ data = [cache. p]
50+ dataiterate = false
51+ end
52+
53+ epochs = if cache. solver_args. epochs === nothing
4654 if cache. solver_args. maxiters === nothing
47- throw (ArgumentError (" The number of epochs must be specified with either the epochs or maxiters kwarg." ))
55+ throw (ArgumentError (" The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data) ." ))
4856 else
49- cache. solver_args. maxiters
57+ cache. solver_args. maxiters / length (data)
5058 end
5159 else
5260 cache. solver_args. epochs
5361 end
5462
55- maxiters = Optimization. _check_and_convert_maxiters (maxiters)
56- if maxiters === nothing
57- throw (ArgumentError (" The number of epochs must be specified as the epochs or maxiters kwarg." ))
58- end
59-
60- if OptimizationBase. isa_dataiterator (cache. p)
61- data = cache. p
62- dataiterate = true
63- else
64- data = [cache. p]
65- dataiterate = false
63+ epochs = Optimization. _check_and_convert_maxiters (epochs)
64+ if epochs === nothing
65+ throw (ArgumentError (" The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)." ))
6666 end
6767
6868 opt = cache. opt
@@ -75,21 +75,35 @@ function SciMLBase.__solve(cache::OptimizationCache{
7575 min_θ = cache. u0
7676
7777 state = Optimisers. setup (opt, θ)
78-
78+ iterations = 0
79+ fevals = 0
80+ gevals = 0
7981 t0 = time ()
8082 Optimization. @withprogress cache. progress name= " Training" begin
81- for epoch in 1 : maxiters
83+ for epoch in 1 : epochs
8284 for (i, d) in enumerate (data)
8385 if cache. f. fg != = nothing && dataiterate
8486 x = cache. f. fg (G, θ, d)
87+ iterations += 1
88+ fevals += 1
89+ gevals += 1
8590 elseif dataiterate
8691 cache. f. grad (G, θ, d)
8792 x = cache. f (θ, d)
93+ iterations += 1
94+ fevals += 2
95+ gevals += 1
8896 elseif cache. f. fg != = nothing
8997 x = cache. f. fg (G, θ)
98+ iterations += 1
99+ fevals += 1
100+ gevals += 1
90101 else
91102 cache. f. grad (G, θ)
92103 x = cache. f (θ)
104+ iterations += 1
105+ fevals += 2
106+ gevals += 1
93107 end
94108 opt_state = Optimization. OptimizationState (
95109 iter = i + (epoch - 1 ) * length (data),
@@ -112,7 +126,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
112126 min_err = x
113127 min_θ = copy (θ)
114128 end
115- if i == maxiters # Last iter, revert to best.
129+ if i == length (data) # Last iter, revert to best.
116130 opt = min_opt
117131 x = min_err
118132 θ = min_θ
@@ -132,10 +146,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
132146 end
133147
134148 t1 = time ()
135- stats = Optimization. OptimizationStats (; iterations = maxiters ,
136- time = t1 - t0, fevals = maxiters , gevals = maxiters )
149+ stats = Optimization. OptimizationStats (; iterations,
150+ time = t1 - t0, fevals, gevals)
137151 SciMLBase. build_solution (cache, cache. opt, θ, first (x)[1 ], stats = stats)
138- # here should be build_solution to create the output message
139152end
140153
141154end
0 commit comments