Skip to content

Commit a31e59e

Browse files
committed
up
1 parent 7228ab8 commit a31e59e

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
numpy_to_pil,
6868
)
6969
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
70+
from ..utils.testing_utils import torch_device
7071
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
7172

7273

@@ -109,7 +110,7 @@
109110
LIBRARIES.append(library)
110111

111112
# TODO: support single-device namings
112-
SUPPORTED_DEVICE_MAP = ["balanced", "cuda"]
113+
SUPPORTED_DEVICE_MAP = ["balanced"] + [torch_device]
113114

114115
logger = logging.get_logger(__name__)
115116

tests/pipelines/test_pipelines_common.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2339,6 +2339,26 @@ def test_torch_dtype_dict(self):
23392339
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
23402340
)
23412341

2342+
@require_torch_accelerator
2343+
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
2344+
components = self.get_dummy_components()
2345+
pipe = self.pipeline_class(**components)
2346+
pipe = pipe.to(torch_device)
2347+
pipe.set_progress_bar_config(disable=None)
2348+
2349+
torch.manual_seed(0)
2350+
inputs = self.get_dummy_inputs(torch_device)
2351+
inputs["generator"] = torch.manual_seed(0)
2352+
out = pipe(**inputs)[0]
2353+
2354+
with tempfile.TemporaryDirectory() as tmpdir:
2355+
pipe.save_pretrained(tmpdir)
2356+
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)
2357+
inputs = self.get_dummy_inputs(torch_device)
2358+
loaded_out = loaded_pipe(**inputs)[0]
2359+
max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
2360+
self.assertLess(max_diff, expected_max_difference)
2361+
23422362

23432363
@is_staging_test
23442364
class PipelinePushToHubTester(unittest.TestCase):

0 commit comments

Comments
 (0)