diff --git a/examples/models/checkpoint.py b/examples/models/checkpoint.py index ee3fb560429..c84a689b951 100644 --- a/examples/models/checkpoint.py +++ b/examples/models/checkpoint.py @@ -64,7 +64,7 @@ def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]: mismatched_dtypes = [ (key, value.dtype) for key, value in checkpoint.items() - if value.dtype != dtype + if hasattr(value, "dtype") and value.dtype != dtype ] if len(mismatched_dtypes) > 0: print(