|
defmodule Axon.Optimizers do
  @moduledoc false

  # Deprecated compatibility shim. Each constructor below forwards to its
  # Polaris.Optimizers counterpart, folding the old positional learning-rate
  # argument into the keyword list that Polaris expects. The historical
  # default learning rates are preserved per function.

  @deprecated "Use Polaris.Optimizers.adabelief/1 instead"
  def adabelief(learning_rate \\ 1.0e-3, opts \\ []) do
    Polaris.Optimizers.adabelief(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.adagrad/1 instead"
  def adagrad(learning_rate \\ 1.0e-3, opts \\ []) do
    Polaris.Optimizers.adagrad(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.adam/1 instead"
  def adam(learning_rate \\ 1.0e-3, opts \\ []) do
    Polaris.Optimizers.adam(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.adamw/1 instead"
  def adamw(learning_rate \\ 1.0e-3, opts \\ []) do
    Polaris.Optimizers.adamw(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.lamb/1 instead"
  def lamb(learning_rate \\ 1.0e-2, opts \\ []) do
    Polaris.Optimizers.lamb(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.noisy_sgd/1 instead"
  def noisy_sgd(learning_rate \\ 1.0e-2, opts \\ []) do
    Polaris.Optimizers.noisy_sgd(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.radam/1 instead"
  def radam(learning_rate \\ 1.0e-3, opts \\ []) do
    Polaris.Optimizers.radam(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.rmsprop/1 instead"
  def rmsprop(learning_rate \\ 1.0e-2, opts \\ []) do
    Polaris.Optimizers.rmsprop(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.sgd/1 instead"
  def sgd(learning_rate \\ 1.0e-2, opts \\ []) do
    Polaris.Optimizers.sgd(forward_opts(learning_rate, opts))
  end

  @deprecated "Use Polaris.Optimizers.yogi/1 instead"
  def yogi(learning_rate \\ 1.0e-2, opts \\ []) do
    Polaris.Optimizers.yogi(forward_opts(learning_rate, opts))
  end

  # Prepends the positional learning rate onto the caller's options.
  # A cons cell is structurally identical to `[learning_rate: lr] ++ opts`:
  # the prepended entry wins on Keyword lookup, and any entries the caller
  # supplied are passed through untouched.
  defp forward_opts(learning_rate, opts) do
    [{:learning_rate, learning_rate} | opts]
  end
end
|
0 commit comments