
Commit b76ea11

Merge branch 'main' into skip-better-2
2 parents fb2813d + d3e27e0 commit b76ea11

5 files changed: +30 -22 lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 3 additions & 3 deletions
@@ -96,9 +96,6 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
-        if self.stream is None and self.record_stream:
-            raise ValueError("`record_stream` cannot be True when `stream` is None.")
-
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -513,6 +510,9 @@ def apply_group_offloading(
     else:
         raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
 
+    if not use_stream and record_stream:
+        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
     if offload_type == "block_level":
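
With the check relocated, the incompatible flag combination is rejected up front from the user-facing arguments of `apply_group_offloading`. A minimal sketch of how this surfaces to callers; the checkpoint id and the argument names other than `use_stream`/`record_stream` follow the public diffusers docs and are assumptions here, not part of this diff:

```python
import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

try:
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=2,
        use_stream=False,
        record_stream=True,  # invalid: record_stream requires use_stream=True
    )
except ValueError as err:
    print(err)  # `record_stream` cannot be True when `use_stream=False`.
```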

src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 10 additions & 8 deletions
@@ -110,8 +110,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
         self.patch_size = patch_size
         self.patch_method = patch_method
 
-        self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
-        self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
+        wavelets = _WAVELETS.get(patch_method).clone()
+        arange = torch.arange(wavelets.shape[0])
+
+        self.register_buffer("wavelets", wavelets, persistent=False)
+        self.register_buffer("_arange", arange, persistent=False)
 
     def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
         dtype = hidden_states.dtype
@@ -185,12 +188,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
         self.patch_size = patch_size
         self.patch_method = patch_method
 
-        self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
-        self.register_buffer(
-            "_arange",
-            torch.arange(_WAVELETS[patch_method].shape[0]),
-            persistent=False,
-        )
+        wavelets = _WAVELETS.get(patch_method).clone()
+        arange = torch.arange(wavelets.shape[0])
+
+        self.register_buffer("wavelets", wavelets, persistent=False)
+        self.register_buffer("_arange", arange, persistent=False)
 
     def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
         device = hidden_states.device
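
The `.clone()` in both `__init__`s gives each module instance its own copy of the wavelet table instead of registering the shared module-level tensor directly, so buffers of different instances no longer alias one another or the global constant. A small self-contained sketch of the pattern; `_WAVELETS` and `Patcher` below are stand-ins, not the real Cosmos VAE code:

```python
import torch
import torch.nn as nn

# Stand-in for the module-level wavelet table (values are illustrative).
_WAVELETS = {"haar": torch.tensor([0.7071, 0.7071])}

class Patcher(nn.Module):
    def __init__(self, patch_method: str = "haar") -> None:
        super().__init__()
        # clone() so the registered buffer owns its own storage rather than
        # aliasing the shared constant (and every other instance's buffer).
        wavelets = _WAVELETS.get(patch_method).clone()
        arange = torch.arange(wavelets.shape[0])

        self.register_buffer("wavelets", wavelets, persistent=False)
        self.register_buffer("_arange", arange, persistent=False)

a, b = Patcher(), Patcher()
print(a.wavelets.data_ptr() == b.wavelets.data_ptr())          # False: no shared storage
print(a.wavelets.data_ptr() == _WAVELETS["haar"].data_ptr())   # False: global left untouched
```

Aliased buffers can be awkward for serialization and offloading paths (safetensors, for instance, rejects tensors that share storage), which appears to be the kind of situation the disk-offloading test later in this commit exercises.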

src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Lines changed: 4 additions & 2 deletions
@@ -23,12 +23,14 @@
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import OmniGenTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .processor_omnigen import OmniGenMultiModalProcessor
 
 
+if is_torchvision_available():
+    from .processor_omnigen import OmniGenMultiModalProcessor
+
 if is_torch_xla_available():
     XLA_AVAILABLE = True
 else:
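
Both OmniGen changes apply the same soft-dependency pattern: the torchvision-dependent pieces are imported only when the library is present, so merely importing the pipeline module no longer requires torchvision. A minimal sketch of the pattern, assuming `is_torchvision_available` is exposed from `diffusers.utils` as the relative import in the diff suggests; `build_image_transform` is an illustrative stand-in, not diffusers code:

```python
from diffusers.utils import is_torchvision_available

if is_torchvision_available():
    from torchvision import transforms
else:
    transforms = None  # the module still imports; torchvision is only needed at use time

def build_image_transform(size: int):
    # Fail with a clear message when the optional dependency is actually exercised.
    if transforms is None:
        raise ImportError("torchvision is required for OmniGen image preprocessing.")
    return transforms.Compose([transforms.Resize(size), transforms.ToTensor()])
```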

src/diffusers/pipelines/omnigen/processor_omnigen.py

Lines changed: 6 additions & 1 deletion
@@ -18,7 +18,12 @@
 import numpy as np
 import torch
 from PIL import Image
-from torchvision import transforms
+
+from ...utils import is_torchvision_available
+
+
+if is_torchvision_available():
+    from torchvision import transforms
 
 
 def crop_image(pil_image, max_image_size):
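
For reference, availability helpers of this kind are usually thin wrappers that probe for the package without importing it; a hedged sketch of that idea (the real `is_torchvision_available` lives in diffusers' utils and may differ in detail):

```python
import importlib.util

def is_torchvision_available() -> bool:
    # True when a torchvision distribution is importable, without importing it yet.
    return importlib.util.find_spec("torchvision") is not None
```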

tests/models/test_modeling_common.py

Lines changed: 7 additions & 8 deletions
@@ -1528,14 +1528,16 @@ def test_fn(storage_dtype, compute_dtype):
         test_fn(torch.float8_e5m2, torch.float32)
         test_fn(torch.float8_e4m3fn, torch.bfloat16)
 
+    @torch.no_grad()
     def test_layerwise_casting_inference(self):
         from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
 
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+        model = self.model_class(**config)
+        model.eval()
+        model.to(torch_device)
+        base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
 
         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1573,6 +1575,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
 
     @require_torch_accelerator
+    @torch.no_grad()
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0
@@ -1706,10 +1709,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")
 
-        torch.manual_seed(0)
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
@@ -1725,7 +1724,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
                 **additional_kwargs,
            )
             has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
-            assert has_safetensors, "No safetensors found in the directory."
+            self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
             _ = model(**inputs_dict)[0]
 
     def test_auto_model(self, expected_max_diff=5e-5):
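
The two `@torch.no_grad()` decorators keep these inference-only tests from recording autograd state, which trims memory during the forward passes, and the bare `assert` is swapped for `self.assertTrue`, which survives `python -O` and reports through unittest. A minimal sketch of the decorator pattern outside the test suite; `TinyModel` is a placeholder, not a diffusers model:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

@torch.no_grad()  # same idea as the decorated tests: no autograd graph is built
def run_inference(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    model.eval()
    return model(x)

out = run_inference(TinyModel(), torch.randn(2, 8))
print(out.requires_grad)  # False: activations carry no grad history
```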
