Skip to content

Commit 9cf25bd

Browse files
authored
Allow none tensor checkpoint values
1 parent 38384a2 commit 9cf25bd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/models/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
6464
mismatched_dtypes = [
6565
(key, value.dtype)
6666
for key, value in checkpoint.items()
67-
if value.dtype != dtype
67+
if hasattr(value, 'dtype') and value.dtype != dtype
6868
]
6969
if len(mismatched_dtypes) > 0:
7070
print(

0 commit comments

Comments
 (0)