Commit 1857e49 (parent 7e0e593)

Fix warnings (#560)

10 files changed (+82, -4077 lines)

lib/axon/defn.ex

Lines changed: 3 additions & 0 deletions
@@ -22,4 +22,7 @@ defmodule Axon.Defn do
 
   @impl true
   def __partitions_options__(_), do: raise("not implemented")
+
+  @impl true
+  def __to_backend__(_), do: raise("not implemented")
 end
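For context: the `@impl true` annotations indicate `Axon.Defn` implements a compiler behaviour, and this hunk stubs the `__to_backend__/1` callback the same way `__partitions_options__/1` was already stubbed, which presumably silences a missing-callback warning this commit targets. A minimal, self-contained sketch of the raise-on-call stub pattern, using a made-up `Transport`/`Car` behaviour rather than the real Nx one:

```elixir
defmodule Transport do
  @callback drive(String.t()) :: :ok
  @callback fly(String.t()) :: :ok
end

defmodule Car do
  @behaviour Transport

  @impl true
  def drive(destination), do: IO.puts("Driving to #{destination}")

  # Raise-on-call stub: the behaviour is satisfied, so the compiler emits
  # no missing-callback warning, while any accidental call fails loudly.
  @impl true
  def fly(_destination), do: raise("not implemented")
end
```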

lib/axon/optimizers.ex

Lines changed: 10 additions & 187 deletions
@@ -1,230 +1,53 @@
 defmodule Axon.Optimizers do
   @moduledoc false
-  alias Polaris.Updates
 
-  @doc """
-  Adabelief optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `0.0`
-    * `:eps_root` - numerical stability term. Defaults to `1.0e-16`
-
-  ## References
-
-    * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468)
-  """
   @deprecated "Use Polaris.Optimizers.adabelief/1 instead"
   def adabelief(learning_rate \\ 1.0e-3, opts \\ []) do
-    Updates.scale_by_belief(opts)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.adabelief([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  Adagrad optimizer.
-
-  ## Options
-
-    * `:eps` - numerical stability term. Defaults to `1.0e-7`
-
-  ## References
-
-    * [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
-  """
   @deprecated "Use Polaris.Optimizers.adagrad/1 instead"
   def adagrad(learning_rate \\ 1.0e-3, opts \\ []) do
-    Updates.scale_by_rss(opts)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.adagrad([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  Adam optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `1.0e-15`
-
-  ## References
-
-    * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
-  """
   @deprecated "Use Polaris.Optimizers.adam/1 instead"
   def adam(learning_rate \\ 1.0e-3, opts \\ []) do
-    Updates.scale_by_adam(opts)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.adam([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  Adam with weight decay optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-    * `:decay` - weight decay. Defaults to `0.0`
-  """
   @deprecated "Use Polaris.Optimizers.adamw/1 instead"
   def adamw(learning_rate \\ 1.0e-3, opts \\ []) do
-    {decay, opts} = Keyword.pop(opts, :decay, 0.0)
-
-    Updates.scale_by_adam(opts)
-    |> Updates.add_decayed_weights(decay: decay)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.adamw([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  Lamb optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-    * `:decay` - weight decay. Defaults to `0.0`
-    * `:min_norm` - minimum norm value. Defaults to `0.0`
-
-  ## References
-
-    * [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962)
-  """
   @deprecated "Use Polaris.Optimizers.lamb/1 instead"
   def lamb(learning_rate \\ 1.0e-2, opts \\ []) do
-    {decay, opts} = Keyword.pop(opts, :decay, 0.0)
-    {min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0)
-
-    Updates.scale_by_adam(opts)
-    |> Updates.add_decayed_weights(decay: decay)
-    |> Updates.scale_by_trust_ratio(min_norm: min_norm)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.lamb([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  Noisy SGD optimizer.
-
-  ## Options
-
-    * `:eta` - used to compute variance of noise distribution. Defaults to `0.1`
-    * `:gamma` - used to compute variance of noise distribution. Defaults to `0.55`
-  """
   @deprecated "Use Polaris.Optimizers.noisy_sgd/1 instead"
   def noisy_sgd(learning_rate \\ 1.0e-2, opts \\ []) do
-    scale_by_learning_rate(learning_rate)
-    |> Updates.add_noise(opts)
+    Polaris.Optimizers.noisy_sgd([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  Rectified Adam optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-    * `:threshold` - threshold term. Defaults to `5.0`
-
-  ## References
-
-    * [On the Variance of Adaptive Learning Rate and Beyond](https://arxiv.org/pdf/1908.03265.pdf)
-  """
   @deprecated "Use Polaris.Optimizers.radam/1 instead"
   def radam(learning_rate \\ 1.0e-3, opts \\ []) do
-    Updates.scale_by_radam(opts)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.radam([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  RMSProp optimizer.
-
-  ## Options
-
-    * `:centered` - whether to scale by centered root of EMA of squares. Defaults to `false`
-    * `:momentum` - momentum term. If set, uses SGD with momentum and decay set
-      to value of this term.
-    * `:nesterov` - whether or not to use nesterov momentum. Defaults to `false`
-    * `:initial_scale` - initial value of EMA. Defaults to `0.0`
-    * `:decay` - EMA decay rate. Defaults to `0.9`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-  """
   @deprecated "Use Polaris.Optimizers.rmsprop/1 instead"
   def rmsprop(learning_rate \\ 1.0e-2, opts \\ []) do
-    {centered, opts} = Keyword.pop(opts, :centered, false)
-    {nesterov?, opts} = Keyword.pop(opts, :nesterov, false)
-    {momentum, opts} = Keyword.pop(opts, :momentum, nil)
-
-    combinator =
-      if centered do
-        Updates.scale_by_stddev(opts)
-      else
-        Updates.scale_by_rms(opts)
-      end
-      |> scale_by_learning_rate(learning_rate)
-
-    if momentum,
-      do: Updates.trace(combinator, decay: momentum, nesterov: nesterov?),
-      else: combinator
+    Polaris.Optimizers.rmsprop([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  SGD optimizer.
-
-  ## Options
-
-    * `:momentum` - momentum term. If set, uses SGD with momentum and decay set
-      to value of this term.
-    * `:nesterov` - whether or not to use nesterov momentum. Defaults to `false`
-  """
   @deprecated "Use Polaris.Optimizers.sgd/1 instead"
   def sgd(learning_rate \\ 1.0e-2, opts \\ []) do
-    momentum = opts[:momentum]
-    nesterov? = opts[:nesterov] || false
-
-    if momentum do
-      Updates.trace(decay: momentum, nesterov: nesterov?)
-      |> scale_by_learning_rate(learning_rate)
-    else
-      scale_by_learning_rate(learning_rate)
-    end
+    Polaris.Optimizers.sgd([learning_rate: learning_rate] ++ opts)
   end
 
-  @doc """
-  Yogi optimizer.
-
-  ## Options
-
-    * `:initial_accumulator_value` - initial value for first and second moment. Defaults to `0.0`
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-
-  ## References
-
-    * [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf)
-  """
   @deprecated "Use Polaris.Optimizers.yogi/1 instead"
   def yogi(learning_rate \\ 1.0e-2, opts \\ []) do
-    Updates.scale_by_yogi(opts)
-    |> scale_by_learning_rate(learning_rate)
-  end
-
-  ## Helpers
-
-  defp scale_by_learning_rate(combinator \\ Updates.identity(), lr)
-
-  defp scale_by_learning_rate(combinator, schedule) when is_function(schedule, 1) do
-    Updates.scale_by_schedule(combinator, fn count -> Nx.negate(schedule.(count)) end)
-  end
-
-  defp scale_by_learning_rate(combinator, lr) do
-    Updates.scale_by_state(combinator, -lr)
+    Polaris.Optimizers.yogi([learning_rate: learning_rate] ++ opts)
   end
 end
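This file applies the same mechanical rewrite ten times over: each `Axon.Optimizers` function keeps its signature, gains a `@deprecated` attribute, and delegates to the matching `Polaris.Optimizers` function, folding the positional `learning_rate` into the options list. For callers, the migration looks like this (the hyperparameter values are illustrative):

```elixir
# Before: deprecated Axon API with a positional learning rate.
optimizer = Axon.Optimizers.adam(1.0e-3, b1: 0.9)

# After: Polaris takes everything, including the learning rate,
# as keyword options.
optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-3, b1: 0.9)
```

Note that the wrappers build `[learning_rate: learning_rate] ++ opts`, so the positional argument is prepended to the list; since `Keyword` lookups return the first match, it takes precedence over any `:learning_rate` a caller might also pass in `opts`.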
