@@ -690,7 +690,7 @@ defmodule Axon.Loop do
690
690
loop
691
691
|> log ( & supervised_log_message_fn / 1 ,
692
692
event: :iteration_completed ,
693
- filter: [ every: log_interval ]
693
+ filter: [ every: { :epoch , log_interval } ]
694
694
)
695
695
|> log ( fn _ -> "\n " end , event: :epoch_completed )
696
696
else
@@ -1912,6 +1912,29 @@ defmodule Axon.Loop do
1912
1912
end )
1913
1913
end
1914
1914
1915
+ defp update_counts ( % State { event_counts: event_counts } = state , event )
1916
+ when event in [ :iteration_started , :iteration_completed ] do
1917
+ updated_counts =
1918
+ Map . update ( event_counts , event , % { total: 1 , epoch: 1 } , fn total_and_epoch ->
1919
+ total_and_epoch
1920
+ |> Map . update! ( :total , & ( & 1 + 1 ) )
1921
+ |> Map . update! ( :epoch , & ( & 1 + 1 ) )
1922
+ end )
1923
+
1924
+ % { state | event_counts: updated_counts }
1925
+ end
1926
+
1927
+ defp update_counts ( % State { event_counts: event_counts } = state , event )
1928
+ when event in [ :epoch_halted , :epoch_completed ] do
1929
+ updated_counts =
1930
+ event_counts
1931
+ |> Map . update ( :iteration_started , % { total: 0 , epoch: 0 } , & % { & 1 | epoch: 0 } )
1932
+ |> Map . update ( :iteration_completed , % { total: 0 , epoch: 0 } , & % { & 1 | epoch: 0 } )
1933
+ |> Map . update ( event , 1 , & ( & 1 + 1 ) )
1934
+
1935
+ % { state | event_counts: updated_counts }
1936
+ end
1937
+
1915
1938
defp update_counts ( % State { event_counts: event_counts } = state , event ) do
1916
1939
% { state | event_counts: Map . update ( event_counts , event , 1 , fn x -> x + 1 end ) }
1917
1940
end
@@ -2165,29 +2188,53 @@ defmodule Axon.Loop do
2165
2188
2166
2189
:first ->
2167
2190
fn % State { event_counts: counts } , event ->
2168
- counts [ event ] == 1
2191
+ case counts [ event ] do
2192
+ 1 -> true
2193
+ % { total: 1 } -> true
2194
+ _ -> false
2195
+ end
2169
2196
end
2170
2197
2171
2198
filters when is_list ( filters ) ->
2172
2199
Enum . reduce ( filters , fn _ , _ -> true end , fn
2200
+ { :every , { key , n } } , acc ->
2201
+ fn state , event ->
2202
+ acc . ( state , event ) and filter_every_n ( state , event , key , n )
2203
+ end
2204
+
2173
2205
{ :every , n } , acc ->
2174
2206
fn state , event ->
2175
- acc . ( state , event ) and filter_every_n ( state , event , n )
2207
+ acc . ( state , event ) and filter_every_n ( state , event , :total , n )
2208
+ end
2209
+
2210
+ { :before , { key , n } } , acc ->
2211
+ fn state , event ->
2212
+ acc . ( state , event ) and filter_before_n ( state , event , key , n )
2176
2213
end
2177
2214
2178
2215
{ :before , n } , acc ->
2179
2216
fn state , event ->
2180
- acc . ( state , event ) and filter_before_n ( state , event , n )
2217
+ acc . ( state , event ) and filter_before_n ( state , event , :total , n )
2218
+ end
2219
+
2220
+ { :after , { key , n } } , acc ->
2221
+ fn state , event ->
2222
+ acc . ( state , event ) and filter_after_n ( state , event , key , n )
2181
2223
end
2182
2224
2183
2225
{ :after , n } , acc ->
2184
2226
fn state , event ->
2185
- acc . ( state , event ) and filter_after_n ( state , event , n )
2227
+ acc . ( state , event ) and filter_after_n ( state , event , :total , n )
2228
+ end
2229
+
2230
+ { :once , { key , n } } , acc ->
2231
+ fn state , event ->
2232
+ acc . ( state , event ) and filter_once_n ( state , event , key , n )
2186
2233
end
2187
2234
2188
2235
{ :once , n } , acc ->
2189
2236
fn state , event ->
2190
- acc . ( state , event ) and filter_once_n ( state , event , n )
2237
+ acc . ( state , event ) and filter_once_n ( state , event , :total , n )
2191
2238
end
2192
2239
end )
2193
2240
@@ -2204,20 +2251,31 @@ defmodule Axon.Loop do
2204
2251
end
2205
2252
end
2206
2253
2207
- defp filter_every_n ( % State { event_counts: counts } , event , n ) do
2208
- rem ( counts [ event ] - 1 , n ) == 0
2254
+ defp filter_every_n ( % State { event_counts: counts } , event , key , n ) do
2255
+ count = get_count ( counts , event , key )
2256
+ rem ( count - 1 , n ) == 0
2209
2257
end
2210
2258
2211
- defp filter_after_n ( % State { event_counts: counts } , event , n ) do
2212
- counts [ event ] > n
2259
+ defp filter_after_n ( % State { event_counts: counts } , event , key , n ) do
2260
+ count = get_count ( counts , event , key )
2261
+ count > n
2213
2262
end
2214
2263
2215
- defp filter_before_n ( % State { event_counts: counts } , event , n ) do
2216
- counts [ event ] < n
2264
+ defp filter_before_n ( % State { event_counts: counts } , event , key , n ) do
2265
+ count = get_count ( counts , event , key )
2266
+ count < n
2217
2267
end
2218
2268
2219
- defp filter_once_n ( % State { event_counts: counts } , event , n ) do
2220
- counts [ event ] == n
2269
+ defp filter_once_n ( % State { event_counts: counts } , event , key , n ) do
2270
+ count = get_count ( counts , event , key )
2271
+ count == n
2272
+ end
2273
+
2274
+ defp get_count ( counts , event , key ) do
2275
+ case counts [ event ] do
2276
+ % { ^ key => count } -> count
2277
+ count -> count
2278
+ end
2221
2279
end
2222
2280
2223
2281
# JIT-compiles the given function if jit_compile? is true
0 commit comments