
Commit 3e4a6bd

Authored by sayakpaul, with SunMarc, stevhliu, and pcuenca
[Core] add "balanced" device_map support to pipelines (#6857)
* get device <-> component mapping when using multiple gpus.
* condition the device_map bits.
* relax condition
* device_map progress.
* device_map enhancement
* some cleaning up and debugging
* Apply suggestions from code review

  Co-authored-by: Marc Sun <[email protected]>
* incorporate suggestions from PR.
* remove multi-gpu condition for now.
* guard check the component -> device mapping
* fix: device_memory variable
* dispatching transformers model to have force_hooks=True
* better guarding for transformers device_map
* introduce support balanced_low_memory and balanced_ultra_low_memory.
* remove device_map patch.
* fix: intermediate variable scoping.
* fix: condition in cpu offload.
* fix: flax class restrictions.
* remove modifications from cpu_offload and model_offload
* incorporate changes.
* add a simple forward pass test
* add: torch_device in get_inputs()
* add: tests
* remove print
* safe-guard to(), model offloading and cpu offloading when balanced is used as a device_map.
* style
* remove .
* safeguard device_map with more checks and remove invalid device_mapping strategies.
* make a class attribute and adjust tests accordingly.
* fix device_map check
* fix test
* adjust comment
* fix: device_map attribute
* fix: dispatching.
* max_memory test for pipeline
* version guard the tests
* fix guard.
* address review feedback.
* reset_device_map method.
* add: test for reset_hf_device_map
* fix a couple things.
* add reset_device_map() in the error message.
* add tests for checking reset_device_map doesn't have unintended consequences.
* fix reset_device_map and offloading tests.
* create _get_final_device_map utility.
* hf_device_map -> _hf_device_map
* add documentation
* add notes suggested by Marc.
* styling.
* Apply suggestions from code review

  Co-authored-by: Steven Liu <[email protected]>
  Co-authored-by: Pedro Cuenca <[email protected]>
* move updates within gpu condition.
* other docs related things
* note on ignore a device not specified in .
* provide a suggestion if device mapping errors out.
* fix: typo.
* _hf_device_map -> hf_device_map
* Empty-Commit
* add: example hf_device_map.

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent c827e94 commit 3e4a6bd

File tree

7 files changed: +546 -17 lines changed


docs/source/en/training/distributed_inference.md

Lines changed: 73 additions & 0 deletions
@@ -52,6 +52,79 @@ To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](h
 
 </Tip>
 
+### Device placement
+
+> [!WARNING]
+> This feature is experimental and its APIs might change in the future.
+
+With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
+
+For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
+
+* it only works on a single GPU
+* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
+
+To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
+
+> [!TIP]
+> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
+
+```diff
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+-    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
++    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
+)
+image = pipeline("a dog").images[0]
+image
+```
+
+> [!WARNING]
+> Currently, we only support the "balanced" `device_map`. We plan to support more device mapping strategies in the future.
+
+You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
+
+```diff
+from diffusers import DiffusionPipeline
+import torch
+
+max_memory = {0: "1GB", 1: "1GB"}
+pipeline = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    device_map="balanced",
++   max_memory=max_memory
+)
+image = pipeline("a dog").images[0]
+image
+```
+
+If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
+
+By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you could defer to using [`~DiffusionPipeline.enable_sequential_cpu_offload`] and [`~DiffusionPipeline.enable_model_cpu_offload`].
+
+Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
+
+```py
+pipeline.reset_device_map()
+```
+
+Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
+
+```py
+print(pipeline.hf_device_map)
+```
+
+An example device map would look like so:
+
+```bash
+{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
+```
+
 ## PyTorch Distributed
 
 PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
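As a quick, hypothetical illustration of the documentation added above (not part of the commit; it assumes two visible CUDA GPUs, and the memory figure and prompt are placeholders), restricting placement via `max_memory` and later resetting the device map before switching to model offloading could look like this:

```py
import torch
from diffusers import DiffusionPipeline

# Only GPU 0 appears in max_memory, so GPU 1 is ignored during placement.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
    device_map="balanced",
    max_memory={0: "6GB"},
)
print(pipeline.hf_device_map)  # components mapped to device 0, with CPU fallback if they don't fit

# reset_device_map() clears the mapping, after which to() and the
# offloading helpers are allowed again per the docs above.
pipeline.reset_device_map()
pipeline.enable_model_cpu_offload()
image = pipeline("a dog").images[0]
```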

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -699,6 +699,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                        offload_folder=offload_folder,
                        offload_state_dict=offload_state_dict,
                        dtype=torch_dtype,
+                        force_hooks=True,
                    )
                except AttributeError as e:
                    # When using accelerate loading, we do not have the ability to load the state
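The one-line change above passes `force_hooks=True` to Accelerate so that device-placement hooks are attached even when a model ends up entirely on a single device, presumably so the pipeline-level dispatching added in this commit can rely on those hooks being present. A rough, standalone sketch of the same flag on Accelerate's `dispatch_model`, using a toy module (hypothetical, not from the commit, and assuming at least one CUDA GPU):

```py
import torch
from accelerate import dispatch_model

# Toy module; "0" and "1" are the Sequential's child-module names.
toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

# Everything fits on cuda:0, but force_hooks=True still attaches Accelerate's
# device-alignment hooks, so the module behaves like a dispatched model.
toy = dispatch_model(toy, device_map={"0": 0, "1": 0}, force_hooks=True)
print(toy.hf_device_map)  # {'0': 0, '1': 0}
```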

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 229 additions & 5 deletions
@@ -22,15 +22,19 @@
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-from huggingface_hub import (
-    model_info,
-)
+from huggingface_hub import model_info
+from huggingface_hub.utils import validate_hf_hub_args
 from packaging import version
 
+from .. import __version__
 from ..utils import (
+    FLAX_WEIGHTS_NAME,
+    ONNX_EXTERNAL_WEIGHTS_NAME,
+    ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     get_class_from_dynamic_module,
+    is_accelerate_available,
     is_peft_available,
     is_transformers_available,
     logging,
@@ -44,9 +48,12 @@
     from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
-from huggingface_hub.utils import validate_hf_hub_args
 
-from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
+if is_accelerate_available():
+    import accelerate
+    from accelerate import dispatch_model
+    from accelerate.hooks import remove_hook_from_module
+    from accelerate.utils import compute_module_sizes, get_max_memory
 
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -376,6 +383,207 @@ def _get_pipeline_class(
     return pipeline_cls
 
 
+def _load_empty_model(
+    library_name: str,
+    class_name: str,
+    importable_classes: List[Any],
+    pipelines: Any,
+    is_pipeline_module: bool,
+    name: str,
+    torch_dtype: Union[str, torch.dtype],
+    cached_folder: Union[str, os.PathLike],
+    **kwargs,
+):
+    # retrieve class objects.
+    class_obj, _ = get_class_obj_and_candidates(
+        library_name,
+        class_name,
+        importable_classes,
+        pipelines,
+        is_pipeline_module,
+        component_name=name,
+        cache_dir=cached_folder,
+    )
+
+    if is_transformers_available():
+        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
+    else:
+        transformers_version = "N/A"
+
+    # Determine library.
+    is_transformers_model = (
+        is_transformers_available()
+        and issubclass(class_obj, PreTrainedModel)
+        and transformers_version >= version.parse("4.20.0")
+    )
+    diffusers_module = importlib.import_module(__name__.split(".")[0])
+    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
+
+    model = None
+    config_path = cached_folder
+    user_agent = {
+        "diffusers": __version__,
+        "file_type": "model",
+        "framework": "pytorch",
+    }
+
+    if is_diffusers_model:
+        # Load config and then the model on meta.
+        config, unused_kwargs, commit_hash = class_obj.load_config(
+            os.path.join(config_path, name),
+            cache_dir=cached_folder,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=kwargs.pop("force_download", False),
+            resume_download=kwargs.pop("resume_download", False),
+            proxies=kwargs.pop("proxies", None),
+            local_files_only=kwargs.pop("local_files_only", False),
+            token=kwargs.pop("token", None),
+            revision=kwargs.pop("revision", None),
+            subfolder=kwargs.pop("subfolder", None),
+            user_agent=user_agent,
+        )
+        with accelerate.init_empty_weights():
+            model = class_obj.from_config(config, **unused_kwargs)
+    elif is_transformers_model:
+        config_class = getattr(class_obj, "config_class", None)
+        if config_class is None:
+            raise ValueError("`config_class` cannot be None. Please double-check the model.")
+
+        config = config_class.from_pretrained(
+            cached_folder,
+            subfolder=name,
+            force_download=kwargs.pop("force_download", False),
+            resume_download=kwargs.pop("resume_download", False),
+            proxies=kwargs.pop("proxies", None),
+            local_files_only=kwargs.pop("local_files_only", False),
+            token=kwargs.pop("token", None),
+            revision=kwargs.pop("revision", None),
+            user_agent=user_agent,
+        )
+        with accelerate.init_empty_weights():
+            model = class_obj(config)
+
+    if model is not None:
+        model = model.to(dtype=torch_dtype)
+    return model
+
+
+def _assign_components_to_devices(
+    module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
+):
+    device_ids = list(device_memory.keys())
+    device_cycle = device_ids + device_ids[::-1]
+    device_memory = device_memory.copy()
+
+    device_id_component_mapping = {}
+    current_device_index = 0
+    for component in module_sizes:
+        device_id = device_cycle[current_device_index % len(device_cycle)]
+        component_memory = module_sizes[component]
+        curr_device_memory = device_memory[device_id]
+
+        # If the GPU doesn't fit the current component offload to the CPU.
+        if component_memory > curr_device_memory:
+            device_id_component_mapping["cpu"] = [component]
+        else:
+            if device_id not in device_id_component_mapping:
+                device_id_component_mapping[device_id] = [component]
+            else:
+                device_id_component_mapping[device_id].append(component)
+
+            # Update the device memory.
+            device_memory[device_id] -= component_memory
+            current_device_index += 1
+
+    return device_id_component_mapping
+
+
+def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
+    # To avoid circular import problem.
+    from diffusers import pipelines
+
+    torch_dtype = kwargs.get("torch_dtype", torch.float32)
+
+    # Load each module in the pipeline on a meta device so that we can derive the device map.
+    init_empty_modules = {}
+    for name, (library_name, class_name) in init_dict.items():
+        if class_name.startswith("Flax"):
+            raise ValueError("Flax pipelines are not supported with `device_map`.")
+
+        # Define all importable classes
+        is_pipeline_module = hasattr(pipelines, library_name)
+        importable_classes = ALL_IMPORTABLE_CLASSES
+        loaded_sub_model = None
+
+        # Use passed sub model or load class_name from library_name
+        if name in passed_class_obj:
+            # if the model is in a pipeline module, then we load it from the pipeline
+            # check that passed_class_obj has correct parent class
+            maybe_raise_or_warn(
+                library_name,
+                library,
+                class_name,
+                importable_classes,
+                passed_class_obj,
+                name,
+                is_pipeline_module,
+            )
+            with accelerate.init_empty_weights():
+                loaded_sub_model = passed_class_obj[name]
+
+        else:
+            loaded_sub_model = _load_empty_model(
+                library_name=library_name,
+                class_name=class_name,
+                importable_classes=importable_classes,
+                pipelines=pipelines,
+                is_pipeline_module=is_pipeline_module,
+                pipeline_class=pipeline_class,
+                name=name,
+                torch_dtype=torch_dtype,
+                cached_folder=kwargs.get("cached_folder", None),
+                force_download=kwargs.get("force_download", None),
+                resume_download=kwargs.get("resume_download", None),
+                proxies=kwargs.get("proxies", None),
+                local_files_only=kwargs.get("local_files_only", None),
+                token=kwargs.get("token", None),
+                revision=kwargs.get("revision", None),
+            )
+
+        if loaded_sub_model is not None:
+            init_empty_modules[name] = loaded_sub_model
+
+    # determine device map
+    # Obtain a sorted dictionary for mapping the model-level components
+    # to their sizes.
+    module_sizes = {
+        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+        for module_name, module in init_empty_modules.items()
+        if isinstance(module, torch.nn.Module)
+    }
+    module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
+
+    # Obtain maximum memory available per device (GPUs only).
+    max_memory = get_max_memory(max_memory)
+    max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
+    max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
+
+    # Obtain a dictionary mapping the model-level components to the available
+    # devices based on the maximum memory and the model sizes.
+    device_id_component_mapping = _assign_components_to_devices(
+        module_sizes, max_memory, device_mapping_strategy=device_map
+    )
+
+    # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
+    final_device_map = {}
+    for device_id, components in device_id_component_mapping.items():
+        for component in components:
+            final_device_map[component] = device_id
+
+    return final_device_map
+
+
 def load_sub_model(
     library_name: str,
     class_name: str,
@@ -493,6 +701,22 @@ def load_sub_model(
         # else load from the root directory
         loaded_sub_model = load_method(cached_folder, **loading_kwargs)
 
+    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
+        # remove hooks
+        remove_hook_from_module(loaded_sub_model, recurse=True)
+        needs_offloading_to_cpu = device_map[""] == "cpu"
+
+        if needs_offloading_to_cpu:
+            dispatch_model(
+                loaded_sub_model,
+                state_dict=loaded_sub_model.state_dict(),
+                device_map=device_map,
+                force_hooks=True,
+                main_device=0,
+            )
+        else:
+            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+
     return loaded_sub_model
 
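To make the new `_assign_components_to_devices` / `_get_final_device_map` logic above easier to follow, here is a simplified, self-contained sketch of the "balanced" strategy with made-up component sizes. This is an illustration only, not the utility itself; the real code works on byte sizes from `compute_module_sizes` and per-device budgets from `get_max_memory`, and returns a device -> components mapping before it is flattened.

```py
# Components are visited largest-first and devices are cycled forward then
# backward (e.g. 1, 0, 0, 1, ...), so consecutive large components land on
# different GPUs. Sizes and budgets below are toy MB values.
module_sizes = {"unet": 1_700, "safety_checker": 600, "text_encoder": 250, "vae": 170}
device_memory = {1: 8_000, 0: 8_000}

device_cycle = list(device_memory) + list(device_memory)[::-1]
assignment, idx = {}, 0
for component, size in sorted(module_sizes.items(), key=lambda kv: kv[1], reverse=True):
    device_id = device_cycle[idx % len(device_cycle)]
    if size > device_memory[device_id]:
        assignment[component] = "cpu"          # doesn't fit: fall back to CPU
    else:
        assignment[component] = device_id
        device_memory[device_id] -= size       # reserve the memory just used
    idx += 1

print(assignment)  # {'unet': 1, 'safety_checker': 0, 'text_encoder': 0, 'vae': 1}
```

The resulting flat mapping has the same shape as the `hf_device_map` example shown in the documentation diff above, e.g. `{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}`.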