Skip to content

Commit 57cd12f

Browse files
authored
Fix missing state in training (#579)
* Inspect * Inspect more * Raise on key * inspect * Fix? * Again * Again * Fix axon * It works
1 parent efd4c1f commit 57cd12f

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

lib/axon/loop.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1556,7 +1556,7 @@ defmodule Axon.Loop do
15561556
is set, the loop will raise on any cache miss during the training loop. Defaults
15571557
to true.
15581558
1559-
* `:force_garbage_collect?` - whether or not to force garbage collection after each
1559+
* `:force_garbage_collection?` - whether or not to force garbage collection after each
15601560
iteration. This may help avoid OOMs when training large models, but it will slow
15611561
training down.
15621562

lib/axon/model_state.ex

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,27 @@ defmodule Axon.ModelState do
222222
end
223223

224224
defp tree_get(data, access) when is_list(access) do
225-
Enum.reduce(access, %{}, &Map.put(&2, &1, Map.fetch!(data, &1)))
225+
Enum.reduce(access, %{}, fn key, acc ->
226+
case data do
227+
%{^key => val} ->
228+
Map.put(acc, key, val)
229+
230+
%{} ->
231+
acc
232+
end
233+
end)
226234
end
227235

228236
defp tree_get(data, access) when is_map(access) do
229237
Enum.reduce(access, %{}, fn {key, value}, acc ->
230-
tree = tree_get(data[key], value)
231-
Map.put(acc, key, tree)
238+
case data do
239+
%{^key => val} ->
240+
tree = tree_get(val, value)
241+
Map.put(acc, key, tree)
242+
243+
%{} ->
244+
acc
245+
end
232246
end)
233247
end
234248

0 commit comments

Comments
 (0)