Skip to content

Commit c10c677

Browse files
committed
do not fail on unexpected return value of load_state_dict
1 parent c12b990 commit c10c677

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

bioimageio/core/backends/pytorch_backend.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,23 @@ def load_torch_state_dict(
158158

159159
incompatible = model.load_state_dict(state)
160160
if (
161-
incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
162-
and incompatible.missing_keys
161+
isinstance(incompatible, tuple)
162+
and hasattr(incompatible, "missing_keys")
163+
and hasattr(incompatible, "unexpected_keys")
163164
):
164-
logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
165+
if incompatible.missing_keys:
166+
logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
165167

166-
if (
167-
incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
168-
and incompatible.unexpected_keys
169-
):
170-
logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys)
168+
if hasattr(incompatible, "unexpected_keys") and incompatible.unexpected_keys:
169+
logger.warning(
170+
"Unexpected state dict keys: {}", incompatible.unexpected_keys
171+
)
172+
else:
173+
logger.warning(
174+
"`model.load_state_dict()` unexpectedly returned: {} "
175+
+ "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
176+
(s[:20] + "..." if len(s := str(incompatible)) > 20 else s),
177+
)
171178

172179
return model
173180

0 commit comments

Comments
 (0)