
Commit 12eeb25

update
1 parent 1ddf3f3 commit 12eeb25

3 files changed: +8 -44 lines changed


src/diffusers/models/model_loading_utils.py
Lines changed: 2 additions & 1 deletion

@@ -264,7 +264,8 @@ def load_model_dict_into_meta(
         old_param = None
 
         if old_param is not None:
-            if dtype is None:
+            # Do not cast parameters if the model is quantized
+            if dtype is None and hf_quantizer is None:
                 param = param.to(old_param.dtype)
 
             if old_param.is_contiguous():
src/diffusers/utils/testing_utils.py
Lines changed: 0 additions & 40 deletions

@@ -320,21 +320,6 @@ def require_torch_multi_gpu(test_case):
     return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
 
 
-def require_torch_multi_accelerator(test_case):
-    """
-    Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are skipped on a machine
-    without multiple hardware accelerators.
-    """
-    if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
-
-    import torch
-
-    return unittest.skipUnless(
-        torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
-    )(test_case)
-
-
 def require_torch_accelerator_with_fp16(test_case):
     """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
     return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(

@@ -369,31 +354,6 @@ def require_big_gpu_with_torch_cuda(test_case):
     )(test_case)
 
 
-def require_big_accelerator(test_case):
-    """
-    Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
-    Flux, SD3, Cog, etc.
-    """
-    if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
-
-    import torch
-
-    if not (torch.cuda.is_available() or torch.xpu.is_available()):
-        return unittest.skip("test requires PyTorch CUDA")(test_case)
-
-    if torch.xpu.is_available():
-        device_properties = torch.xpu.get_device_properties(0)
-    else:
-        device_properties = torch.cuda.get_device_properties(0)
-
-    total_memory = device_properties.total_memory / (1024**3)
-    return unittest.skipUnless(
-        total_memory >= BIG_GPU_MEMORY,
-        f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
-    )(test_case)
-
-
 def require_torch_accelerator_with_training(test_case):
     """Decorator marking a test that requires an accelerator with support for training."""
     return unittest.skipUnless(
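Downstream test suites that still reference the removed decorator can recreate it locally. A minimal stand-in mirroring the deleted skip condition (the `hasattr` guard for PyTorch builds without `torch.xpu` is an addition, not from the original helper):

```python
import unittest

import torch


def require_torch_multi_accelerator(test_case):
    # Local re-creation of the removed helper; no longer part of
    # diffusers.utils.testing_utils after this commit.
    has_multi = torch.cuda.device_count() > 1 or (
        hasattr(torch, "xpu") and torch.xpu.device_count() > 1
    )
    return unittest.skipUnless(has_multi, "test requires multiple hardware accelerators")(test_case)
```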

tests/quantization/bnb/test_mixed_int8.py
Lines changed: 6 additions & 3 deletions

@@ -90,13 +90,16 @@ class Base8bitTests(unittest.TestCase):
 
     def get_dummy_inputs(self):
         prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+            map_location="cpu",
         )
         pooled_prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+            map_location="cpu",
         )
         latent_model_input = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+            map_location="cpu",
         )
 
         input_dict_for_transformer = {
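The `map_location="cpu"` argument presumably flows through to `torch.load` (the exact `load_pt` signature is not shown in this diff), which remaps accelerator-saved storages to host memory at deserialization time. A self-contained sketch of that semantics, independent of the helper:

```python
import io

import torch

# Save a tensor, then load it with map_location="cpu": any storages saved from
# an accelerator are remapped to host memory, so the test inputs deserialize
# without requiring a GPU to be present.
buffer = io.BytesIO()
torch.save(torch.ones(2, 2), buffer)
buffer.seek(0)
tensor = torch.load(buffer, map_location="cpu")
assert tensor.device.type == "cpu"
```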
