14 changes: 12 additions & 2 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -554,6 +554,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict
             loaded_sub_model = passed_class_obj[name]
 
         else:
+            sub_model_dtype = (
+                torch_dtype.get(name, torch_dtype.get("_", torch.float32))
+                if isinstance(torch_dtype, dict)
+                else torch_dtype
+            )
Member:

I feel like `_` might be a bit unintuitive. Better to expose full dtype maps, or, in case partial ones are provided, default to `torch.float32` for the rest of the components.

Contributor Author:

Could the key be `default`? Considering how it will work for integrations, a fallback key avoids spelling out, say, `{'transformer': torch.bfloat16, 'text_encoder': torch.float16, 'text_encoder_2': torch.float16, 'text_encoder_3': torch.float16}` for SD3 and `{'transformer': torch.bfloat16, 'text_encoder': torch.float16, 'text_encoder_2': torch.float16}` for Flux. Not a big issue either way, because the components can be obtained from `cls._get_signature_types()`.

Member:

Yeah, no strong opinions.

Contributor Author:

For now it's renamed to `default` to be clearer; we can remove it later if it's not needed.
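For reference, a minimal standalone sketch of the lookup being discussed. `resolve_dtype` is a hypothetical helper that mirrors the inline expression in this diff (shown with the `_` fallback key as committed here):

```python
# Hypothetical helper mirroring the lookup in this diff: per-component
# entry first, then the fallback key ("_"), then torch.float32.
import torch

def resolve_dtype(torch_dtype, name):
    if isinstance(torch_dtype, dict):
        return torch_dtype.get(name, torch_dtype.get("_", torch.float32))
    return torch_dtype  # non-dict values pass through unchanged

dtype_map = {"transformer": torch.bfloat16, "_": torch.float16}
print(resolve_dtype(dtype_map, "transformer"))   # torch.bfloat16 (explicit entry)
print(resolve_dtype(dtype_map, "text_encoder"))  # torch.float16 (fallback key)
print(resolve_dtype(torch.float16, "vae"))       # torch.float16 (passthrough)
```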

             loaded_sub_model = _load_empty_model(
                 library_name=library_name,
                 class_name=class_name,
@@ -562,7 +567,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict
                 is_pipeline_module=is_pipeline_module,
                 pipeline_class=pipeline_class,
                 name=name,
-                torch_dtype=torch_dtype,
+                torch_dtype=sub_model_dtype,
                 cached_folder=kwargs.get("cached_folder", None),
                 force_download=kwargs.get("force_download", None),
                 proxies=kwargs.get("proxies", None),
@@ -578,7 +583,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict
     # Obtain a sorted dictionary for mapping the model-level components
     # to their sizes.
     module_sizes = {
-        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+        module_name: compute_module_sizes(
+            module,
+            dtype=torch_dtype.get(module_name, torch_dtype.get("_", torch.float32))
+            if isinstance(torch_dtype, dict)
+            else torch_dtype,
+        )[""]
         for module_name, module in init_empty_modules.items()
         if isinstance(module, torch.nn.Module)
     }
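For context, `compute_module_sizes` comes from `accelerate`, so the per-component dtype now feeds directly into the memory estimates used for device-map planning. A minimal sketch of the effect, using an illustrative `torch.nn.Linear` module (assumes `accelerate` is installed):

```python
# Sketch: dtype changes the byte size used for device-map planning.
import torch
from accelerate.utils import compute_module_sizes

module = torch.nn.Linear(1024, 1024)
size_fp32 = compute_module_sizes(module, dtype=torch.float32)[""]  # "" keys the whole module
size_bf16 = compute_module_sizes(module, dtype=torch.bfloat16)[""]
print(size_fp32, size_bf16)  # the bf16 estimate is roughly half the fp32 one
```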
14 changes: 11 additions & 3 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -530,9 +530,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                       saved using
                       [`~DiffusionPipeline.save_pretrained`].
-            torch_dtype (`str` or `torch.dtype`, *optional*):
+            torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
-                dtype is automatically derived from the model's weights.
+                dtype is automatically derived from the model's weights. To load submodels with different dtypes,
+                pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default
+                dtype for unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`).
+                If a component is not specified and no default is set, `torch.float32` is used.
             custom_pipeline (`str`, *optional*):
 
             <Tip warning={true}>
@@ -921,14 +924,19 @@ def load_module(name, value):
                 loaded_sub_model = passed_class_obj[name]
             else:
                 # load sub model
+                sub_model_dtype = (
+                    torch_dtype.get(name, torch_dtype.get("_", torch.float32))
+                    if isinstance(torch_dtype, dict)
+                    else torch_dtype
+                )
                 loaded_sub_model = load_sub_model(
                     library_name=library_name,
                     class_name=class_name,
                     importable_classes=importable_classes,
                     pipelines=pipelines,
                     is_pipeline_module=is_pipeline_module,
                     pipeline_class=pipeline_class,
-                    torch_dtype=torch_dtype,
+                    torch_dtype=sub_model_dtype,
                     provider=provider,
                     sess_options=sess_options,
                     device_map=current_device_map,
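Taken together, the two files let `from_pretrained` accept a per-component dtype map. A usage sketch based on the updated docstring; the model id is illustrative and assumes the checkpoint is available locally or via the Hub:

```python
# Usage sketch for the dict-valued torch_dtype; model id is illustrative.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype={"transformer": torch.bfloat16, "_": torch.float16},
)
print(pipe.transformer.dtype)   # torch.bfloat16 (explicit entry)
print(pipe.text_encoder.dtype)  # torch.float16 (falls back to "_")
```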