Skip to content

Commit f3d5bf9

Browse files
authored
Fix iteration counts (#572)
1 parent 88c7823 commit f3d5bf9

File tree

3 files changed

+86
-27
lines changed

3 files changed

+86
-27
lines changed

lib/axon/loop.ex

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ defmodule Axon.Loop do
690690
loop
691691
|> log(&supervised_log_message_fn/1,
692692
event: :iteration_completed,
693-
filter: [every: log_interval]
693+
filter: [every: {:epoch, log_interval}]
694694
)
695695
|> log(fn _ -> "\n" end, event: :epoch_completed)
696696
else
@@ -1912,6 +1912,29 @@ defmodule Axon.Loop do
19121912
end)
19131913
end
19141914

1915+
defp update_counts(%State{event_counts: event_counts} = state, event)
1916+
when event in [:iteration_started, :iteration_completed] do
1917+
updated_counts =
1918+
Map.update(event_counts, event, %{total: 1, epoch: 1}, fn total_and_epoch ->
1919+
total_and_epoch
1920+
|> Map.update!(:total, &(&1 + 1))
1921+
|> Map.update!(:epoch, &(&1 + 1))
1922+
end)
1923+
1924+
%{state | event_counts: updated_counts}
1925+
end
1926+
1927+
defp update_counts(%State{event_counts: event_counts} = state, event)
1928+
when event in [:epoch_halted, :epoch_completed] do
1929+
updated_counts =
1930+
event_counts
1931+
|> Map.update(:iteration_started, %{total: 0, epoch: 0}, &%{&1 | epoch: 0})
1932+
|> Map.update(:iteration_completed, %{total: 0, epoch: 0}, &%{&1 | epoch: 0})
1933+
|> Map.update(event, 1, &(&1 + 1))
1934+
1935+
%{state | event_counts: updated_counts}
1936+
end
1937+
19151938
defp update_counts(%State{event_counts: event_counts} = state, event) do
19161939
%{state | event_counts: Map.update(event_counts, event, 1, fn x -> x + 1 end)}
19171940
end
@@ -2165,29 +2188,53 @@ defmodule Axon.Loop do
21652188

21662189
:first ->
21672190
fn %State{event_counts: counts}, event ->
2168-
counts[event] == 1
2191+
case counts[event] do
2192+
1 -> true
2193+
%{total: 1} -> true
2194+
_ -> false
2195+
end
21692196
end
21702197

21712198
filters when is_list(filters) ->
21722199
Enum.reduce(filters, fn _, _ -> true end, fn
2200+
{:every, {key, n}}, acc ->
2201+
fn state, event ->
2202+
acc.(state, event) and filter_every_n(state, event, key, n)
2203+
end
2204+
21732205
{:every, n}, acc ->
21742206
fn state, event ->
2175-
acc.(state, event) and filter_every_n(state, event, n)
2207+
acc.(state, event) and filter_every_n(state, event, :total, n)
2208+
end
2209+
2210+
{:before, {key, n}}, acc ->
2211+
fn state, event ->
2212+
acc.(state, event) and filter_before_n(state, event, key, n)
21762213
end
21772214

21782215
{:before, n}, acc ->
21792216
fn state, event ->
2180-
acc.(state, event) and filter_before_n(state, event, n)
2217+
acc.(state, event) and filter_before_n(state, event, :total, n)
2218+
end
2219+
2220+
{:after, {key, n}}, acc ->
2221+
fn state, event ->
2222+
acc.(state, event) and filter_after_n(state, event, key, n)
21812223
end
21822224

21832225
{:after, n}, acc ->
21842226
fn state, event ->
2185-
acc.(state, event) and filter_after_n(state, event, n)
2227+
acc.(state, event) and filter_after_n(state, event, :total, n)
2228+
end
2229+
2230+
{:once, {key, n}}, acc ->
2231+
fn state, event ->
2232+
acc.(state, event) and filter_once_n(state, event, key, n)
21862233
end
21872234

21882235
{:once, n}, acc ->
21892236
fn state, event ->
2190-
acc.(state, event) and filter_once_n(state, event, n)
2237+
acc.(state, event) and filter_once_n(state, event, :total, n)
21912238
end
21922239
end)
21932240

@@ -2204,20 +2251,31 @@ defmodule Axon.Loop do
22042251
end
22052252
end
22062253

2207-
defp filter_every_n(%State{event_counts: counts}, event, n) do
2208-
rem(counts[event] - 1, n) == 0
2254+
defp filter_every_n(%State{event_counts: counts}, event, key, n) do
2255+
count = get_count(counts, event, key)
2256+
rem(count - 1, n) == 0
22092257
end
22102258

2211-
defp filter_after_n(%State{event_counts: counts}, event, n) do
2212-
counts[event] > n
2259+
defp filter_after_n(%State{event_counts: counts}, event, key, n) do
2260+
count = get_count(counts, event, key)
2261+
count > n
22132262
end
22142263

2215-
defp filter_before_n(%State{event_counts: counts}, event, n) do
2216-
counts[event] < n
2264+
defp filter_before_n(%State{event_counts: counts}, event, key, n) do
2265+
count = get_count(counts, event, key)
2266+
count < n
22172267
end
22182268

2219-
defp filter_once_n(%State{event_counts: counts}, event, n) do
2220-
counts[event] == n
2269+
defp filter_once_n(%State{event_counts: counts}, event, key, n) do
2270+
count = get_count(counts, event, key)
2271+
count == n
2272+
end
2273+
2274+
defp get_count(counts, event, key) do
2275+
case counts[event] do
2276+
%{^key => count} -> count
2277+
count -> count
2278+
end
22212279
end
22222280

22232281
# JIT-compiles the given function if jit_compile? is true

lib/axon/loop/state.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ defmodule Axon.Loop.State do
6060
event_counts: %{
6161
started: 0,
6262
epoch_started: 0,
63-
iteration_started: 0,
64-
iteration_completed: 0,
63+
iteration_started: %{total: 0, epoch: 0},
64+
iteration_completed: %{total: 0, epoch: 0},
6565
epoch_completed: 0,
6666
epoch_halted: 0,
6767
halted: 0,

test/axon/loop_test.exs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -636,26 +636,26 @@ defmodule Axon.LoopTest do
636636
started: 1,
637637
epoch_started: 1,
638638
epoch_completed: 1,
639-
iteration_started: 10,
640-
iteration_completed: 10
639+
iteration_started: %{total: 10, epoch: 0},
640+
iteration_completed: %{total: 10, epoch: 0}
641641
}}
642642

643643
assert_received {:epoch_started,
644644
%{
645645
started: 1,
646646
epoch_started: 2,
647647
epoch_completed: 1,
648-
iteration_started: 10,
649-
iteration_completed: 10
648+
iteration_started: %{total: 10, epoch: 0},
649+
iteration_completed: %{total: 10, epoch: 0}
650650
}}
651651

652652
assert_received {:epoch_completed,
653653
%{
654654
started: 1,
655655
epoch_started: 2,
656656
epoch_completed: 2,
657-
iteration_started: 20,
658-
iteration_completed: 20
657+
iteration_started: %{total: 20, epoch: 0},
658+
iteration_completed: %{total: 20, epoch: 0}
659659
}}
660660

661661
refute_received _
@@ -786,7 +786,7 @@ defmodule Axon.LoopTest do
786786

787787
test "supports function filter" do
788788
fun = fn
789-
%{event_counts: counts}, event -> counts[event] == 5
789+
%{event_counts: counts}, event -> counts[event][:total] == 5
790790
end
791791

792792
run_dummy_loop!(:iteration_started, fun, 5, 10)
@@ -854,18 +854,19 @@ defmodule Axon.LoopTest do
854854
test "saves a checkpoint on custom events", %{loop: loop} do
855855
data = List.duplicate({Nx.iota({1, 1}), Nx.iota({1, 1})}, 5)
856856

857-
assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: 15}} =
857+
assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: %{total: 15}}} =
858858
loop
859859
|> Map.put(:output_transform, & &1)
860-
|> Loop.checkpoint(event: :iteration_completed, filter: [every: 2])
860+
|> Loop.checkpoint(event: :iteration_completed, filter: [every: {:epoch, 2}])
861861
|> Loop.run(data, Axon.ModelState.empty(), epochs: 3)
862862

863863
assert [
864864
"checkpoint_0_0.ckpt",
865865
"checkpoint_0_2.ckpt",
866866
"checkpoint_0_4.ckpt",
867-
"checkpoint_1_1.ckpt",
868-
"checkpoint_1_3.ckpt",
867+
"checkpoint_1_0.ckpt",
868+
"checkpoint_1_2.ckpt",
869+
"checkpoint_1_4.ckpt",
869870
"checkpoint_2_0.ckpt",
870871
"checkpoint_2_2.ckpt",
871872
"checkpoint_2_4.ckpt"

0 commit comments

Comments
 (0)