Commit 46a0c6a

feat: cuda device_map for pipelines. (#12122)
* feat: cuda device_map for pipelines.
* up
* up
* empty
* up
1 parent 421ee07 commit 46a0c6a

4 files changed: +38 -7 lines changed

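Taken together, the change lets DiffusionPipeline.from_pretrained accept the current accelerator string (for example "cuda") as a device_map, alongside the existing "balanced" strategy. A rough usage sketch on a CUDA machine; the model id and prompt are only illustrative:

import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint; any diffusers pipeline repo is loaded the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="cuda",  # before this commit, only device_map="balanced" was accepted
)
image = pipe("an astronaut riding a horse on the moon").images[0]
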
src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 3 additions & 0 deletions

@@ -613,6 +613,9 @@ def _assign_components_to_devices(


 def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
+    # TODO: seperate out different device_map methods when it gets to it.
+    if device_map != "balanced":
+        return device_map
     # To avoid circular import problem.
     from diffusers import pipelines

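For orientation, a minimal sketch of the dispatch this early return introduces; the helper below is a made-up stand-in, not the library code, and the "balanced" branch is reduced to a placeholder:

def final_device_map_sketch(device_map, component_names):
    if device_map != "balanced":
        # e.g. "cuda" is returned unchanged and later applied to every component
        return device_map
    # placeholder for the existing balanced, memory-aware per-component assignment
    return {name: 0 for name in component_names}


print(final_device_map_sketch("cuda", ["unet", "vae"]))      # -> "cuda"
print(final_device_map_sketch("balanced", ["unet", "vae"]))  # -> {"unet": 0, "vae": 0}
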
src/diffusers/pipelines/pipeline_utils.py

Lines changed: 10 additions & 7 deletions

@@ -108,7 +108,7 @@
 for library in LOADABLE_CLASSES:
     LIBRARIES.append(library)

-SUPPORTED_DEVICE_MAP = ["balanced"]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]

 logger = logging.get_logger(__name__)

@@ -988,12 +988,15 @@ def load_module(name, value):
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
         for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
             # 7.1 device_map shenanigans
-            if final_device_map is not None and len(final_device_map) > 0:
-                component_device = final_device_map.get(name, None)
-                if component_device is not None:
-                    current_device_map = {"": component_device}
-                else:
-                    current_device_map = None
+            if final_device_map is not None:
+                if isinstance(final_device_map, dict) and len(final_device_map) > 0:
+                    component_device = final_device_map.get(name, None)
+                    if component_device is not None:
+                        current_device_map = {"": component_device}
+                    else:
+                        current_device_map = None
+                elif isinstance(final_device_map, str):
+                    current_device_map = final_device_map

             # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
             class_name = class_name[4:] if class_name.startswith("Flax") else class_name

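The first hunk registers the current accelerator as a supported device_map value; the second hunk consumes it: a dict (the "balanced" result) still yields a per-component {"": device} map, while a plain string is forwarded to each component loader unchanged. A self-contained sketch of just that branching, with a hypothetical helper name:

def resolve_component_device_map(name, final_device_map):
    current_device_map = None
    if final_device_map is not None:
        if isinstance(final_device_map, dict) and len(final_device_map) > 0:
            # "balanced" path: look up the device assigned to this specific component
            component_device = final_device_map.get(name, None)
            if component_device is not None:
                current_device_map = {"": component_device}
        elif isinstance(final_device_map, str):
            # accelerator-string path: e.g. "cuda" is passed through as-is
            current_device_map = final_device_map
    return current_device_map


print(resolve_component_device_map("unet", {"unet": 0, "vae": 1}))  # {'': 0}
print(resolve_component_device_map("unet", "cuda"))                 # cuda
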
src/diffusers/utils/torch_utils.py

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,7 @@
 PyTorch utilities: Utilities related to PyTorch
 """

+import functools
 from typing import List, Optional, Tuple, Union

 from . import logging
@@ -168,6 +169,7 @@ def get_torch_cuda_device_capability():
         return None


+@functools.lru_cache
 def get_device():
     if torch.cuda.is_available():
         return "cuda"

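Since get_device() is now called at import time inside SUPPORTED_DEVICE_MAP, memoizing it with functools.lru_cache keeps the accelerator probe to a single evaluation. A small sketch of the caching pattern, not the library function itself:

import functools


@functools.lru_cache
def probe_device():
    # Stand-in for get_device(): the body runs only on the first call,
    # later calls return the cached string.
    print("probing accelerator...")
    return "cuda"


probe_device()  # prints once and returns "cuda"
probe_device()  # served from the cache, no second probe
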
tests/pipelines/test_pipelines_common.py

Lines changed: 23 additions & 0 deletions

@@ -2339,6 +2339,29 @@ def test_torch_dtype_dict(self):
                 f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
             )

+    @require_torch_accelerator
+    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        torch.manual_seed(0)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = torch.manual_seed(0)
+        out = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)
+            for component in loaded_pipe.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            inputs["generator"] = torch.manual_seed(0)
+            loaded_out = loaded_pipe(**inputs)[0]
+            max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
+            self.assertLess(max_diff, expected_max_difference)
+

 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):

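The test is defined on the shared pipeline tester mixin, so it is inherited and run by each concrete pipeline test class. On a machine with an accelerator, something along these lines should select it (exact invocation may differ per setup):

pytest tests/pipelines -k "test_pipeline_with_accelerator_device_map"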