Skip to content

Commit 7a2e9bc

Browse files
committed
Rename garbage collect option
1 parent 52be5dc commit 7a2e9bc

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

lib/axon/loop.ex

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,7 +1590,7 @@ defmodule Axon.Loop do
15901590
functions. JIT compilation must be used for gradient computations. Defaults
15911591
to true.
15921592
1593-
* `:force_garbage_collect?` - whether or not to force garbage collection after
1593+
* `:garbage_collect` - whether or not to garbage collect after
15941594
each loop iteration. This may prevent OOMs, but it will slow down training.
15951595
15961596
* `:strict?` - whether or not to compile step functions strictly. If this flag
@@ -1608,7 +1608,7 @@ defmodule Axon.Loop do
16081608
{max_epochs, opts} = Keyword.pop(opts, :epochs, 1)
16091609
{max_iterations, opts} = Keyword.pop(opts, :iterations, -1)
16101610
{jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true)
1611-
{force_garbage_collection?, opts} = Keyword.pop(opts, :force_garbage_collection?, false)
1611+
{garbage_collect, opts} = Keyword.pop(opts, :garbage_collect, false)
16121612
{strict?, jit_opts} = Keyword.pop(opts, :strict?, true)
16131613
debug? = Keyword.get(jit_opts, :debug, false)
16141614

@@ -1680,8 +1680,8 @@ defmodule Axon.Loop do
16801680
batch_fn =
16811681
{:non_compiled, build_batch_fn(step_fn, metric_fns), jit_compile?, strict?, jit_opts}
16821682

1683-
epoch_start..epoch_end//1
1684-
|> Enum.reduce_while(
1683+
Enum.reduce_while(
1684+
epoch_start..epoch_end//1,
16851685
{batch_fn, final_metrics_map, state},
16861686
fn epoch, {batch_fn, final_metrics_map, loop_state} ->
16871687
case fire_event(:epoch_started, handler_fns, loop_state, debug?) do
@@ -1697,7 +1697,14 @@ defmodule Axon.Loop do
16971697
end
16981698

16991699
{time, status_batch_fn_and_state} =
1700-
:timer.tc(&run_epoch/6, [batch_fn, handler_fns, state, data, debug?, force_garbage_collection?])
1700+
:timer.tc(&run_epoch/6, [
1701+
batch_fn,
1702+
handler_fns,
1703+
state,
1704+
data,
1705+
debug?,
1706+
garbage_collect
1707+
])
17011708

17021709
if debug? do
17031710
Logger.debug("Axon.Loop finished running epoch in #{us_to_ms(time)} ms")
@@ -1784,7 +1791,7 @@ defmodule Axon.Loop do
17841791
end
17851792
end
17861793

1787-
defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, force_garbage_collection?) do
1794+
defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, garbage_collect) do
17881795
Enum.reduce_while(data, {:continue, batch_fn, loop_state}, fn data, {_, batch_fn, state} ->
17891796
case fire_event(:iteration_started, handler_fns, state, debug?) do
17901797
{:halt_epoch, state} ->
@@ -1841,7 +1848,7 @@ defmodule Axon.Loop do
18411848
{:halt, {:halt_loop, batch_fn, state}}
18421849

18431850
{:continue, state} ->
1844-
if force_garbage_collection? do
1851+
if garbage_collect do
18451852
:erlang.garbage_collect()
18461853
end
18471854

0 commit comments

Comments
 (0)