Skip to content

Commit 2c58c64

Browse files
committed
Add warning, change _ to default
1 parent db77006 commit 2c58c64

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
555555

556556
else:
557557
sub_model_dtype = (
558-
torch_dtype.get(name, torch_dtype.get("_", torch.float32))
558+
torch_dtype.get(name, torch_dtype.get("default", torch.float32))
559559
if isinstance(torch_dtype, dict)
560560
else torch_dtype
561561
)
@@ -585,7 +585,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
585585
module_sizes = {
586586
module_name: compute_module_sizes(
587587
module,
588-
dtype=torch_dtype.get(module_name, torch_dtype.get("_", torch.float32))
588+
dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
589589
if isinstance(torch_dtype, dict)
590590
else torch_dtype,
591591
)[""]

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -534,8 +534,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
534534
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
535535
dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
536536
`dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
537-
unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`). If
538-
a component is not specifed and no default is set, `torch.float32` is used.
537+
unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default': torch.float16}`).
538+
If a component is not specified and no default is set, `torch.float32` is used.
539539
custom_pipeline (`str`, *optional*):
540540
541541
<Tip warning={true}>
@@ -858,6 +858,20 @@ def load_module(name, value):
858858
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
859859
)
860860

861+
# Check `torch_dtype` map for unused keys
862+
if isinstance(torch_dtype, dict):
863+
extra_keys_dtype = set(torch_dtype.keys()) - set(passed_class_obj.keys())
864+
extra_keys_obj = set(passed_class_obj.keys()) - set(torch_dtype.keys())
865+
if len(extra_keys_dtype) > 0:
866+
logger.warning(
867+
f"Expected `{list(passed_class_obj.keys())}`, got extra `torch_dtype` keys `{extra_keys_dtype}`."
868+
)
869+
if len(extra_keys_obj) > 0:
870+
logger.warning(
871+
f"Expected `{list(passed_class_obj.keys())}`, missing `torch_dtype` keys `{extra_keys_obj}`."
872+
" Using `default` or `torch.float32`."
873+
)
874+
861875
# Special case: safety_checker must be loaded separately when using `from_flax`
862876
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
863877
raise NotImplementedError(
@@ -925,7 +939,7 @@ def load_module(name, value):
925939
else:
926940
# load sub model
927941
sub_model_dtype = (
928-
torch_dtype.get(name, torch_dtype.get("_", torch.float32))
942+
torch_dtype.get(name, torch_dtype.get("default", torch.float32))
929943
if isinstance(torch_dtype, dict)
930944
else torch_dtype
931945
)

0 commit comments

Comments
 (0)