@@ -1590,7 +1590,7 @@ defmodule Axon.Loop do
1590
1590
functions. JIT compilation must be used for gradient computations. Defaults
1591
1591
to true.
1592
1592
1593
- * `:force_garbage_collect? ` - whether or not to force garbage collection after
1593
+ * `:garbage_collect ` - whether or not to garbage collect after
1594
1594
each loop iteration. This may prevent OOMs, but it will slow down training.
1595
1595
1596
1596
* `:strict?` - whether or not to compile step functions strictly. If this flag
@@ -1608,7 +1608,7 @@ defmodule Axon.Loop do
1608
1608
{ max_epochs , opts } = Keyword . pop ( opts , :epochs , 1 )
1609
1609
{ max_iterations , opts } = Keyword . pop ( opts , :iterations , - 1 )
1610
1610
{ jit_compile? , opts } = Keyword . pop ( opts , :jit_compile? , true )
1611
- { force_garbage_collection? , opts } = Keyword . pop ( opts , :force_garbage_collection? , false )
1611
+ { garbage_collect , opts } = Keyword . pop ( opts , :garbage_collect , false )
1612
1612
{ strict? , jit_opts } = Keyword . pop ( opts , :strict? , true )
1613
1613
debug? = Keyword . get ( jit_opts , :debug , false )
1614
1614
@@ -1680,8 +1680,8 @@ defmodule Axon.Loop do
1680
1680
batch_fn =
1681
1681
{ :non_compiled , build_batch_fn ( step_fn , metric_fns ) , jit_compile? , strict? , jit_opts }
1682
1682
1683
- epoch_start .. epoch_end // 1
1684
- |> Enum . reduce_while (
1683
+ Enum . reduce_while (
1684
+ epoch_start .. epoch_end // 1 ,
1685
1685
{ batch_fn , final_metrics_map , state } ,
1686
1686
fn epoch , { batch_fn , final_metrics_map , loop_state } ->
1687
1687
case fire_event ( :epoch_started , handler_fns , loop_state , debug? ) do
@@ -1697,7 +1697,14 @@ defmodule Axon.Loop do
1697
1697
end
1698
1698
1699
1699
{ time , status_batch_fn_and_state } =
1700
- :timer . tc ( & run_epoch / 6 , [ batch_fn , handler_fns , state , data , debug? , force_garbage_collection? ] )
1700
+ :timer . tc ( & run_epoch / 6 , [
1701
+ batch_fn ,
1702
+ handler_fns ,
1703
+ state ,
1704
+ data ,
1705
+ debug? ,
1706
+ garbage_collect
1707
+ ] )
1701
1708
1702
1709
if debug? do
1703
1710
Logger . debug ( "Axon.Loop finished running epoch in #{ us_to_ms ( time ) } ms" )
@@ -1784,7 +1791,7 @@ defmodule Axon.Loop do
1784
1791
end
1785
1792
end
1786
1793
1787
- defp run_epoch ( batch_fn , handler_fns , loop_state , data , debug? , force_garbage_collection? ) do
1794
+ defp run_epoch ( batch_fn , handler_fns , loop_state , data , debug? , garbage_collect ) do
1788
1795
Enum . reduce_while ( data , { :continue , batch_fn , loop_state } , fn data , { _ , batch_fn , state } ->
1789
1796
case fire_event ( :iteration_started , handler_fns , state , debug? ) do
1790
1797
{ :halt_epoch , state } ->
@@ -1841,7 +1848,7 @@ defmodule Axon.Loop do
1841
1848
{ :halt , { :halt_loop , batch_fn , state } }
1842
1849
1843
1850
{ :continue , state } ->
1844
- if force_garbage_collection? do
1851
+ if garbage_collect do
1845
1852
:erlang . garbage_collect ( )
1846
1853
end
1847
1854
0 commit comments