@@ -534,8 +534,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                 dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
                 `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
-                unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`). If
-                a component is not specifed and no default is set, `torch.float32` is used.
+                unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default': torch.float16}`).
+                If a component is not specified and no default is set, `torch.float32` is used.
             custom_pipeline (`str`, *optional*):

                 <Tip warning={true}>
@@ -858,6 +858,20 @@ def load_module(name, value):
858858 f"Expected types for { key } : { _expected_class_types } , got { class_obj .__class__ .__name__ } ."
859859 )
860860
+        # Check `torch_dtype` map for unused keys
+        if isinstance(torch_dtype, dict):
+            extra_keys_dtype = set(torch_dtype.keys()) - set(passed_class_obj.keys())
+            extra_keys_obj = set(passed_class_obj.keys()) - set(torch_dtype.keys())
+            if len(extra_keys_dtype) > 0:
+                logger.warning(
+                    f"Expected `{list(passed_class_obj.keys())}`, got extra `torch_dtype` keys `{extra_keys_dtype}`."
+                )
+            if len(extra_keys_obj) > 0:
+                logger.warning(
+                    f"Expected `{list(passed_class_obj.keys())}`, missing `torch_dtype` keys `{extra_keys_obj}`,"
+                    " using `default` or `torch.float32`."
+                )
+
         # Special case: safety_checker must be loaded separately when using `from_flax`
         if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
             raise NotImplementedError(
@@ -925,7 +939,7 @@ def load_module(name, value):
             else:
                 # load sub model
                 sub_model_dtype = (
-                    torch_dtype.get(name, torch_dtype.get("_", torch.float32))
+                    torch_dtype.get(name, torch_dtype.get("default", torch.float32))
                     if isinstance(torch_dtype, dict)
                     else torch_dtype
                 )
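
For context, a minimal usage sketch of the per-component `torch_dtype` mapping described in the updated docstring; the checkpoint id and component names below are illustrative, not taken from this PR:

```python
import torch

from diffusers import DiffusionPipeline

# Illustrative only: the component named in the dict ("transformer") gets its
# own dtype; every other component falls back to the "default" entry, or to
# torch.float32 if no default is given (per the updated docstring).
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # hypothetical checkpoint id
    torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
)
```

Note that the `default` key replaces the previous `_` placeholder, so existing dtype maps that use `_` would need to be updated, and mismatches between the `torch_dtype` keys and the explicitly passed components now emit the warnings added above.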