
Commit 4851084

Remove gradient accumulation for now
1 parent f3d5bf9 commit 4851084

3 files changed: +54 −28

examples/generative/text_generator.exs

Lines changed: 3 additions & 3 deletions
@@ -1,8 +1,8 @@
 # Based on https://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/
 Mix.install([
-  {:axon, "~> 0.5"},
-  {:nx, "~> 0.5"},
-  {:exla, "~> 0.5"},
+  {:axon, path: "/Users/sean/projects/axon"},
+  {:nx, "~> 0.7"},
+  {:exla, "~> 0.7"},
   {:req, "~> 0.3.3"}
 ])

lib/axon.ex

Lines changed: 46 additions & 7 deletions
@@ -339,11 +339,26 @@ defmodule Axon do
     name = name(op_name, name)

     id = System.unique_integer([:positive, :monotonic])
-    axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, meta, opts, global_options)
+
+    axon_node =
+      make_node(id, op, name, op_name, mode, inputs, params, args, meta, opts, global_options)
+
     %Axon{output: id, nodes: Map.put(updated_nodes, id, axon_node)}
   end

-  defp make_node(id, op, name, op_name, mode, inputs, params, args, meta, layer_opts, global_options) do
+  defp make_node(
+         id,
+         op,
+         name,
+         op_name,
+         mode,
+         inputs,
+         params,
+         args,
+         meta,
+         layer_opts,
+         global_options
+       ) do
     {:current_stacktrace, [_process_info, _axon_layer | stacktrace]} =
       Process.info(self(), :current_stacktrace)

@@ -469,7 +484,14 @@ defmodule Axon do
     input_shape = opts[:shape]

     output_shape = input_shape && Axon.Shape.input(input_shape)
-    layer(:input, [], name: name, shape: output_shape, meta: meta, op_name: :input, optional: optional)
+
+    layer(:input, [],
+      name: name,
+      shape: output_shape,
+      meta: meta,
+      op_name: :input,
+      optional: optional
+    )
   end

   @doc """

@@ -559,7 +581,12 @@
   def constant(number, opts) when is_number(number) do
     opts = Keyword.validate!(opts, [:name, :meta])

-    layer(:constant, [], name: opts[:name], meta: opts[:meta], value: Nx.tensor(number), op_name: :constant)
+    layer(:constant, [],
+      name: opts[:name],
+      meta: opts[:meta],
+      value: Nx.tensor(number),
+      op_name: :constant
+    )
   end

   def constant(value, _) do

@@ -2137,7 +2164,9 @@
   """
   @doc type: :shape
   def resize(%Axon{} = x, resize_shape, opts \\ []) do
-    opts = Keyword.validate!(opts, [:name, :meta, method: :nearest, antialias: true, channels: :last])
+    opts =
+      Keyword.validate!(opts, [:name, :meta, method: :nearest, antialias: true, channels: :last])
+
     channels = opts[:channels]

     layer(:resize, [x],

@@ -2384,7 +2413,12 @@
       Nx.equal(Nx.as_type(x, :s64), opts[:eos_token])
     end

-    layer(fun, [input], eos_token: eos_token, op_name: :mask, meta: opts[:meta], name: opts[:name])
+    layer(fun, [input],
+      eos_token: eos_token,
+      op_name: :mask,
+      meta: opts[:meta],
+      name: opts[:name]
+    )
   end

   @doc """

@@ -3163,7 +3197,12 @@
   def stack_columns(%Axon{} = x, opts \\ []) do
     opts = Keyword.validate!(opts, [:name, ignore: []])

-    layer(:stack_columns, [x], meta: opts[:meta], name: opts[:name], ignore: opts[:ignore], op_name: :stack_columns)
+    layer(:stack_columns, [x],
+      meta: opts[:meta],
+      name: opts[:name],
+      ignore: opts[:ignore],
+      op_name: :stack_columns
+    )
   end

   @doc """
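The lib/axon.ex hunks above are formatting-only: each long `layer/3` call is broken across lines without changing behavior. For orientation, here is a minimal sketch of how these internal constructors surface through the public API; the layer names and shapes below are illustrative and not part of this commit:

```elixir
# Hypothetical model for illustration; Axon.input/2, Axon.dense/3, and
# Axon.constant/2 are public entry points that funnel into layer/3 above.
model =
  Axon.input("features", shape: {nil, 8})
  |> Axon.dense(16, activation: :relu)
  |> Axon.dense(1)

# Axon.constant/2 wraps a number in a constant node, as reformatted above.
one = Axon.constant(1.0, name: "one")
```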

lib/axon/loop.ex

Lines changed: 5 additions & 18 deletions
@@ -322,16 +322,11 @@ defmodule Axon.Loop do
     * `:loss_scale` - type of loss-scaling to use, if any. Loss-scaling is necessary when
       doing mixed precision training for numerical stability. Defaults to `:identity` or
       no loss-scaling.
-
-    * `:gradient_accumulation_steps` - number of gradient accumulation steps to take during
-      training. Gradient accumulation decreases the number of updates by accumulating gradients
-      between steps, increasing the effective batch size on smaller devices. Defaults to 1.
   """
   def train_step(model, loss, optimizer, opts \\ []) do
-    opts = Keyword.validate!(opts, [:seed, loss_scale: :identity, gradient_accumulation_steps: 1])
+    opts = Keyword.validate!(opts, [:seed, loss_scale: :identity])

     loss_scale = opts[:loss_scale] || :identity
-    gradient_accumulation_steps = opts[:gradient_accumulation_steps] || 1

     {init_model_fn, forward_model_fn} = build_model_fns(model, :train, opts)
     loss_fn = build_loss_fn(loss)

@@ -377,12 +372,8 @@
         tar
         |> loss_fn.(model_out.prediction)
         |> then(fn loss ->
-          scaled =
-            loss
-            |> scale_loss.(loss_scale_state)
-            |> Nx.divide(gradient_accumulation_steps)
-
-          {scaled, Nx.divide(loss, gradient_accumulation_steps)}
+          scaled = scale_loss.(loss, loss_scale_state)
+          {scaled, loss}
         end)

       {model_out, scaled_loss, unscaled_loss}
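With accumulation removed, `train_step/4` now transforms the loss only through the configured loss scale. A minimal usage sketch of the remaining `:loss_scale` option; `model` and `train_data` are placeholders, and `:dynamic` is assumed here to be one of the accepted loss-scaling strategies:

```elixir
# Hypothetical training loop; only the :loss_scale option survives this
# commit, and it is forwarded from trainer/4 into train_step/4.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam, loss_scale: :dynamic)
|> Axon.Loop.run(train_data, %{}, epochs: 5)
```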
@@ -665,17 +656,13 @@
     * `:loss_scale` - type of loss-scaling to use, if any. Loss-scaling is necessary when
       doing mixed precision training for numerical stability. Defaults to `:identity` or
       no loss-scaling.
-
-    * `:gradient_accumulation_steps` - number of gradient accumulation steps to take during
-      training. Gradient accumulation decreases the number of updates by accumulating gradients
-      between steps, increasing the effective batch size on smaller devices. Defaults to 1.
   """
   def trainer(model, loss, optimizer, opts \\ []) do
-    opts = Keyword.validate!(opts, [:seed, :loss_scale, :gradient_accumulation_steps, log: 50])
+    opts = Keyword.validate!(opts, [:seed, :loss_scale, log: 50])

     # Build loss now so we can use it as a metric
     loss_fn = build_loss_fn(loss)
-    step_opts = Keyword.take(opts, [:gradient_accumulation_steps, :loss_scale, :seed])
+    step_opts = Keyword.take(opts, [:loss_scale, :seed])
     {init_fn, step_fn} = train_step(model, loss_fn, optimizer, step_opts)

     log_interval = opts[:log] || 50
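As the removed docstring notes, gradient accumulation trades update frequency for effective batch size on smaller devices. Anyone depending on the deleted option can recover the behavior by hand: run several microbatches, average their gradients, and apply a single optimizer update. A minimal sketch under stated assumptions: parameters are a flat map of tensors, and `grad_fn.(params, batch)` is a hypothetical helper returning gradients shaped like `params` (neither is Axon API):

```elixir
defmodule ManualAccumulation do
  # Sum gradients over all microbatches, then divide by the step count.
  # Dividing by `steps` mirrors the Nx.divide(loss, gradient_accumulation_steps)
  # scaling that this commit removes from train_step/4.
  def accumulated_gradients(grad_fn, params, microbatches) do
    steps = length(microbatches)

    microbatches
    |> Enum.map(&grad_fn.(params, &1))
    |> Enum.reduce(fn grads, acc ->
      # Merge the per-batch gradient maps leaf-by-leaf.
      Map.merge(acc, grads, fn _key, total, g -> Nx.add(total, g) end)
    end)
    |> Map.new(fn {key, total} -> {key, Nx.divide(total, steps)} end)
  end
end
```

The averaged map can then be fed to one optimizer update, so k microbatches produce a single parameter step with an effective batch size k times larger.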
