
Commit c340f9e

Merge branch 'main' into xpu

2 parents fd618b5 + d3e27e0

File tree

8 files changed (+68, -28 lines)


src/diffusers/hooks/group_offloading.py

Lines changed: 3 additions & 3 deletions

@@ -96,9 +96,6 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
-        if self.stream is None and self.record_stream:
-            raise ValueError("`record_stream` cannot be True when `stream` is None.")
-
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -513,6 +510,9 @@ def apply_group_offloading(
         else:
            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
 
+    if not use_stream and record_stream:
+        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+
    _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
    if offload_type == "block_level":
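
The effect of this change is that the `record_stream`/`use_stream` consistency check now fails at the `apply_group_offloading` call itself rather than inside the hook's `__init__`. A minimal sketch of how the new guard surfaces, assuming `apply_group_offloading` is imported from `diffusers.hooks`; the toy module and kwargs are illustrative only, not taken from this commit:

    # Illustrative sketch: the invalid combination now raises before any
    # offloading hooks are attached to the module.
    import torch

    from diffusers.hooks import apply_group_offloading

    module = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

    try:
        apply_group_offloading(
            module,
            onload_device=torch.device("cuda"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
            record_stream=True,  # invalid without a transfer stream
        )
    except ValueError as err:
        print(err)  # `record_stream` cannot be True when `use_stream=False`.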

src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 10 additions & 8 deletions

@@ -110,8 +110,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
         self.patch_size = patch_size
         self.patch_method = patch_method
 
-        self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
-        self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
+        wavelets = _WAVELETS.get(patch_method).clone()
+        arange = torch.arange(wavelets.shape[0])
+
+        self.register_buffer("wavelets", wavelets, persistent=False)
+        self.register_buffer("_arange", arange, persistent=False)
 
     def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
         dtype = hidden_states.dtype
@@ -185,12 +188,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
         self.patch_size = patch_size
         self.patch_method = patch_method
 
-        self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
-        self.register_buffer(
-            "_arange",
-            torch.arange(_WAVELETS[patch_method].shape[0]),
-            persistent=False,
-        )
+        wavelets = _WAVELETS.get(patch_method).clone()
+        arange = torch.arange(wavelets.shape[0])
+
+        self.register_buffer("wavelets", wavelets, persistent=False)
+        self.register_buffer("_arange", arange, persistent=False)
 
     def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
         device = hidden_states.device
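
Previously the patcher and unpatcher registered the module-level `_WAVELETS[patch_method]` tensor directly, so every instance aliased the same storage; cloning before `register_buffer` gives each module its own copy, which avoids surprises when hooks later pin or copy those buffers in place. A minimal sketch of the aliasing difference, using a stand-in constant rather than the diffusers wavelet table:

    # Stand-in sketch: registering a shared tensor vs. a clone.
    import torch

    _TABLE = {"haar": torch.tensor([0.7071, 0.7071])}


    class Patcher(torch.nn.Module):
        def __init__(self, clone: bool):
            super().__init__()
            wavelets = _TABLE["haar"].clone() if clone else _TABLE["haar"]
            self.register_buffer("wavelets", wavelets, persistent=False)


    a, b = Patcher(clone=False), Patcher(clone=False)
    print(a.wavelets.data_ptr() == b.wavelets.data_ptr())  # True: both modules alias one tensor

    c, d = Patcher(clone=True), Patcher(clone=True)
    print(c.wavelets.data_ptr() == d.wavelets.data_ptr())  # False: each module owns its buffer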

src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Lines changed: 4 additions & 2 deletions

@@ -23,12 +23,14 @@
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import OmniGenTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .processor_omnigen import OmniGenMultiModalProcessor
 
 
+if is_torchvision_available():
+    from .processor_omnigen import OmniGenMultiModalProcessor
+
 if is_torch_xla_available():
     XLA_AVAILABLE = True
 else:

src/diffusers/pipelines/omnigen/processor_omnigen.py

Lines changed: 6 additions & 1 deletion

@@ -18,7 +18,12 @@
 import numpy as np
 import torch
 from PIL import Image
-from torchvision import transforms
+
+from ...utils import is_torchvision_available
+
+
+if is_torchvision_available():
+    from torchvision import transforms
 
 
 def crop_image(pil_image, max_image_size):
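
Together with the pipeline change above, this makes torchvision a soft dependency for OmniGen: importing the module no longer fails when torchvision is absent, and only the code paths that actually use it do. A minimal sketch of the same guard pattern outside diffusers internals; `build_resize` is a hypothetical helper, not part of this commit:

    # Hypothetical sketch of the optional-import guard used above.
    from diffusers.utils import is_torchvision_available

    if is_torchvision_available():
        from torchvision import transforms


    def build_resize(size: int):
        # Fail only when the torchvision-backed path is used, not at import time.
        if not is_torchvision_available():
            raise ImportError("This helper requires torchvision: `pip install torchvision`.")
        return transforms.Resize(size)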

tests/models/test_modeling_common.py

Lines changed: 7 additions & 8 deletions

@@ -1527,14 +1527,16 @@ def test_fn(storage_dtype, compute_dtype):
         test_fn(torch.float8_e5m2, torch.float32)
         test_fn(torch.float8_e4m3fn, torch.bfloat16)
 
+    @torch.no_grad()
     def test_layerwise_casting_inference(self):
         from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
 
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+        model = self.model_class(**config)
+        model.eval()
+        model.to(torch_device)
+        base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
 
         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1572,6 +1574,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
 
     @require_torch_accelerator
+    @torch.no_grad()
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0
@@ -1705,10 +1708,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")
 
-        torch.manual_seed(0)
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
@@ -1724,7 +1723,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
             **additional_kwargs,
         )
         has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
-        assert has_safetensors, "No safetensors found in the directory."
+        self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
         _ = model(**inputs_dict)[0]
 
     def test_auto_model(self, expected_max_diff=5e-5):
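
The `@torch.no_grad()` decorators keep the layerwise-casting tests from building autograd graphs during their forward passes, which matters for `test_layerwise_casting_memory` because retained activations would otherwise inflate the measured peak memory. A minimal, framework-independent sketch of the decorator's effect:

    # Sketch: torch.no_grad() as a decorator disables graph construction inside the call.
    import torch


    @torch.no_grad()
    def forward_only(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        return model(x)


    out = forward_only(torch.nn.Linear(4, 4), torch.randn(2, 4))
    print(out.requires_grad)  # False: no autograd graph (or its memory) is kept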

tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py

Lines changed: 24 additions & 1 deletion

@@ -71,7 +71,6 @@ class HunyuanVideoFramepackPipelineFastTests(
     )
 
     supports_dduf = False
-    # there is no xformers processor for Flux
     test_xformers_attention = False
     test_layerwise_casting = True
     test_group_offloading = True
@@ -360,6 +359,30 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             "VAE tiling should not affect the inference results",
         )
 
+    def test_float16_inference(self, expected_max_diff=0.2):
+        # NOTE: this test needs a higher tolerance because of multiple forwards through
+        # the model, which compounds the overall fp32 vs fp16 numerical differences. It
+        # shouldn't be expected that the results are the same, so we bump the tolerance.
+        return super().test_float16_inference(expected_max_diff)
+
+    @unittest.skip("The image_encoder uses SiglipVisionModel, which does not support sequential CPU offloading.")
+    def test_sequential_cpu_offload_forward_pass(self):
+        # https://github.com/huggingface/transformers/blob/21cb353b7b4f77c6f5f5c3341d660f86ff416d04/src/transformers/models/siglip/modeling_siglip.py#L803
+        # This is because it instantiates its attention layer from torch.nn.MultiheadAttention, which calls
+        # `torch.nn.functional.multi_head_attention_forward` with the weights and bias. Since the hook is never
+        # triggered with a forward pass call, the weights stay on the CPU. There are more examples where we skip
+        # this test because of MHA (example: HunyuanDiT because of the AttentionPooling layer).
+        pass
+
+    @unittest.skip("The image_encoder uses SiglipVisionModel, which does not support sequential CPU offloading.")
+    def test_sequential_offload_forward_pass_twice(self):
+        # https://github.com/huggingface/transformers/blob/21cb353b7b4f77c6f5f5c3341d660f86ff416d04/src/transformers/models/siglip/modeling_siglip.py#L803
+        # This is because it instantiates its attention layer from torch.nn.MultiheadAttention, which calls
+        # `torch.nn.functional.multi_head_attention_forward` with the weights and bias. Since the hook is never
+        # triggered with a forward pass call, the weights stay on the CPU. There are more examples where we skip
+        # this test because of MHA (example: HunyuanDiT because of the AttentionPooling layer).
+        pass
+
     # TODO(aryan): Create a dummy gemma model with smol vocab size
     @unittest.skip(
         "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."

tests/pipelines/hunyuandit/test_hunyuan_dit.py

Lines changed: 10 additions & 2 deletions

@@ -124,14 +124,22 @@ def test_inference(self):
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
 
-    @unittest.skip("Not supported.")
+    @unittest.skip("The HunyuanDiT Attention pooling layer does not support sequential CPU offloading.")
     def test_sequential_cpu_offload_forward_pass(self):
         # TODO(YiYi) need to fix later
+        # This is because it instantiates its attention layer from torch.nn.MultiheadAttention, which calls
+        # `torch.nn.functional.multi_head_attention_forward` with the weights and bias. Since the hook is never
+        # triggered with a forward pass call, the weights stay on the CPU. There are more examples where we skip
+        # this test because of MHA (example: HunyuanVideo Framepack)
         pass
 
-    @unittest.skip("Not supported.")
+    @unittest.skip("The HunyuanDiT Attention pooling layer does not support sequential CPU offloading.")
     def test_sequential_offload_forward_pass_twice(self):
         # TODO(YiYi) need to fix later
+        # This is because it instantiates its attention layer from torch.nn.MultiheadAttention, which calls
+        # `torch.nn.functional.multi_head_attention_forward` with the weights and bias. Since the hook is never
+        # triggered with a forward pass call, the weights stay on the CPU. There are more examples where we skip
+        # this test because of MHA (example: HunyuanVideo Framepack)
         pass
 
     def test_inference_batch_single_identical(self):

tests/pipelines/test_pipelines_common.py

Lines changed: 4 additions & 3 deletions

@@ -2270,9 +2270,10 @@ def enable_group_offload_on_component(pipe, group_offloading_kwargs):
                 if hasattr(module, "_diffusers_hook")
             )
         )
-        for component_name in ["vae", "vqvae"]:
-            if hasattr(pipe, component_name):
-                getattr(pipe, component_name).to(torch_device)
+        for component_name in ["vae", "vqvae", "image_encoder"]:
+            component = getattr(pipe, component_name, None)
+            if isinstance(component, torch.nn.Module):
+                component.to(torch_device)
 
    def run_forward(pipe):
        torch.manual_seed(0)
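
The lookup is now defensive: a component slot that is missing, `None`, or not an `nn.Module` is skipped instead of crashing on `.to()`, which is what lets `image_encoder` join the list safely. A minimal sketch with a stand-in pipeline object:

    # Sketch: getattr with a default plus an isinstance check skips absent or None components.
    from types import SimpleNamespace

    import torch

    pipe = SimpleNamespace(vae=torch.nn.Linear(4, 4), vqvae=None)  # no image_encoder attribute

    for component_name in ["vae", "vqvae", "image_encoder"]:
        component = getattr(pipe, component_name, None)
        if isinstance(component, torch.nn.Module):
            component.to("cpu")
            print(f"moved {component_name}")  # only "vae" is a real module here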
