Skip to content

Commit 3643246

Browse files
committed
allow models to run with a user-provided dtype map instead of a single dtype
1 parent 1826a1e commit 3643246

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
554554
loaded_sub_model = passed_class_obj[name]
555555

556556
else:
557+
sub_model_dtype = (
558+
torch_dtype.get(name, torch_dtype.get("_", torch.float32))
559+
if isinstance(torch_dtype, dict)
560+
else torch_dtype
561+
)
557562
loaded_sub_model = _load_empty_model(
558563
library_name=library_name,
559564
class_name=class_name,
@@ -562,7 +567,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
562567
is_pipeline_module=is_pipeline_module,
563568
pipeline_class=pipeline_class,
564569
name=name,
565-
torch_dtype=torch_dtype,
570+
torch_dtype=sub_model_dtype,
566571
cached_folder=kwargs.get("cached_folder", None),
567572
force_download=kwargs.get("force_download", None),
568573
proxies=kwargs.get("proxies", None),
@@ -578,7 +583,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
578583
# Obtain a sorted dictionary for mapping the model-level components
579584
# to their sizes.
580585
module_sizes = {
581-
module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
586+
module_name: compute_module_sizes(
587+
module,
588+
dtype=torch_dtype.get(module_name, torch_dtype.get("_", torch.float32))
589+
if isinstance(torch_dtype, dict)
590+
else torch_dtype,
591+
)[""]
582592
for module_name, module in init_empty_modules.items()
583593
if isinstance(module, torch.nn.Module)
584594
}

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,9 +530,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
530530
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
531531
saved using
532532
[`~DiffusionPipeline.save_pretrained`].
533-
torch_dtype (`str` or `torch.dtype`, *optional*):
533+
torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
534534
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
535535
dtype is automatically derived from the model's weights.
536+
To load submodels with a different dtype, pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`).
537+
Set the default dtype for unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`).
538+
If a component is not specified and no default is set, `torch.float32` is used.
536539
custom_pipeline (`str`, *optional*):
537540
538541
<Tip warning={true}>
@@ -921,14 +924,19 @@ def load_module(name, value):
921924
loaded_sub_model = passed_class_obj[name]
922925
else:
923926
# load sub model
927+
sub_model_dtype = (
928+
torch_dtype.get(name, torch_dtype.get("_", torch.float32))
929+
if isinstance(torch_dtype, dict)
930+
else torch_dtype
931+
)
924932
loaded_sub_model = load_sub_model(
925933
library_name=library_name,
926934
class_name=class_name,
927935
importable_classes=importable_classes,
928936
pipelines=pipelines,
929937
is_pipeline_module=is_pipeline_module,
930938
pipeline_class=pipeline_class,
931-
torch_dtype=torch_dtype,
939+
torch_dtype=sub_model_dtype,
932940
provider=provider,
933941
sess_options=sess_options,
934942
device_map=current_device_map,

0 commit comments

Comments
 (0)