Skip to content

Commit 135fae7

Browse files
committed
Fix warning
1 parent 7a13000 commit 135fae7

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

lib/bumblebee/conversion/pytorch_params.ex

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,16 @@ defmodule Bumblebee.Conversion.PyTorchParams do
8484

8585
defp state_dict?(%{} = dict) when not is_struct(dict) do
8686
Enum.all?(dict, fn {key, value} ->
87-
is_binary(key) and Nx.LazyContainer.impl_for(value) != nil
87+
is_binary(key) and implements_lazy_container?(value)
8888
end)
8989
end
9090

9191
defp state_dict?(_other), do: false
9292

93+
defp implements_lazy_container?(value) do
94+
Nx.LazyContainer.impl_for(value) != Nx.LazyContainer.Any or Nx.Container.impl_for(value) != nil
95+
end
96+
9397
defp init_params(model, params_expr, pytorch_state, params_mapping) do
9498
layers =
9599
model

0 commit comments

Comments
 (0)