Commit 4622c11

Merge branch 'main' into reuse-attn-mixin

2 parents 21cab86 + bc40398

7 files changed: +220 lines, -5 lines

src/diffusers/models/embeddings.py
Lines changed: 10 additions & 2 deletions

@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
     return emb


-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
     """
     This function generates 1D positional embeddings from a grid.

     Args:
         embed_dim (`int`): The embedding dimension `D`
         pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+        output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
+        dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
+            `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.

     Returns:
         `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")

-    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+    # Auto-detect appropriate dtype if not specified
+    if dtype is None:
+        dtype = torch.float32 if pos.device.type == "mps" else torch.float64
+
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
     omega /= embed_dim / 2.0
     omega = 1.0 / 10000**omega  # (D/2,)
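For illustration, a minimal sketch of how the new `dtype` auto-detection behaves when calling the patched function (the positions and embedding dimension here are made up for the example):

```python
import torch

from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

# Example positions for a sequence of length 16.
pos = torch.arange(16, dtype=torch.float32)

# With dtype=None, frequencies are computed in torch.float64 on CUDA/CPU
# and fall back to torch.float32 on MPS, which lacks float64 support.
emb = get_1d_sincos_pos_embed_from_grid(64, pos, output_type="pt")
print(emb.shape)  # torch.Size([16, 64])

# The previous hard-coded behavior can still be requested explicitly.
emb64 = get_1d_sincos_pos_embed_from_grid(64, pos, output_type="pt", dtype=torch.float64)
```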
src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
Lines changed: 1 addition & 1 deletion

@@ -113,7 +113,7 @@ def encode_prompt(
         negative_prompt=None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
-        _cut_context=False,
+        _cut_context=True,
         attention_mask: Optional[torch.Tensor] = None,
         negative_attention_mask: Optional[torch.Tensor] = None,
     ):

src/diffusers/pipelines/pipeline_loading_utils.py
Lines changed: 15 additions & 0 deletions

@@ -33,6 +33,7 @@
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
+    _maybe_remap_transformers_class,
     deprecate,
     get_class_from_dynamic_module,
     is_accelerate_available,
@@ -356,6 +357,11 @@ def maybe_raise_or_warn(
     """Simple helper method to raise or warn in case incorrect module has been passed"""
     if not is_pipeline_module:
         library = importlib.import_module(library_name)
+
+        # Handle deprecated Transformers classes
+        if library_name == "transformers":
+            class_name = _maybe_remap_transformers_class(class_name) or class_name
+
         class_obj = getattr(library, class_name)
         class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

@@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name):
         class_obj = getattr(pipeline_module, class_name)
     else:
         library = importlib.import_module(library_name)
+
+        # Handle deprecated Transformers classes
+        if library_name == "transformers":
+            class_name = _maybe_remap_transformers_class(class_name) or class_name
+
         class_obj = getattr(library, class_name)

     return class_obj
@@ -416,6 +427,10 @@ def get_class_obj_and_candidates(
    # else we just import it from the library.
    library = importlib.import_module(library_name)

+    # Handle deprecated Transformers classes
+    if library_name == "transformers":
+        class_name = _maybe_remap_transformers_class(class_name) or class_name
+
    class_obj = getattr(library, class_name)
    class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
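All three loading paths use the same `_maybe_remap_transformers_class(class_name) or class_name` idiom: the helper returns the replacement name when a remap applies and `None` otherwise, so `or class_name` keeps the original name untouched. A small self-contained sketch of the pattern (the mapping and helper here are stand-ins, not the diffusers internals):

```python
from typing import Optional

# Stand-in for diffusers' private remapping table.
_REMAP = {"CLIPFeatureExtractor": "CLIPImageProcessor"}


def maybe_remap(class_name: str) -> Optional[str]:
    """Return the replacement class name, or None if no remap is registered."""
    return _REMAP.get(class_name)


# `or class_name` falls back to the original name when the helper returns None.
assert (maybe_remap("CLIPFeatureExtractor") or "CLIPFeatureExtractor") == "CLIPImageProcessor"
assert (maybe_remap("CLIPTextModel") or "CLIPTextModel") == "CLIPTextModel"
```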

src/diffusers/utils/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
 )
-from .deprecation_utils import deprecate
+from .deprecation_utils import _maybe_remap_transformers_class, deprecate
 from .doc_utils import replace_example_docstring
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video

src/diffusers/utils/constants.py
Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@
 DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
-DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
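This one-character fix matters because the true-values set contains only uppercase tokens, so lowercasing the environment value could never match. A quick check, assuming the usual diffusers definition of `ENV_VARS_TRUE_VALUES`:

```python
import os

# Assumed to match diffusers.utils.constants.ENV_VARS_TRUE_VALUES.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

os.environ["DIFFUSERS_DISABLE_REMOTE_CODE"] = "true"

# Before the fix: "true".lower() == "true" is not in the uppercase set,
# so the flag silently stayed False even when explicitly enabled.
assert os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() not in ENV_VARS_TRUE_VALUES

# After the fix: "true".upper() == "TRUE" matches as intended.
assert os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
```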

src/diffusers/utils/deprecation_utils.py
Lines changed: 48 additions & 0 deletions

@@ -4,6 +4,54 @@

 from packaging import version

+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+# Mapping for deprecated Transformers classes to their replacements.
+# This is used to handle models that reference deprecated class names in their configs.
+# Reference: https://github.com/huggingface/transformers/issues/40822
+# Format: {
+#     "DeprecatedClassName": {
+#         "new_class": "NewClassName",
+#         "transformers_version": (">=", "5.0.0"),  # (operation, version) tuple
+#     }
+# }
+_TRANSFORMERS_CLASS_REMAPPING = {
+    "CLIPFeatureExtractor": {
+        "new_class": "CLIPImageProcessor",
+        "transformers_version": (">", "4.57.0"),
+    },
+}
+
+
+def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
+    """
+    Check if a Transformers class should be remapped to a newer version.
+
+    Args:
+        class_name: The name of the class to check.
+
+    Returns:
+        The new class name if remapping should occur, `None` otherwise.
+    """
+    if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
+        return None
+
+    from .import_utils import is_transformers_version
+
+    mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
+    operation, required_version = mapping["transformers_version"]
+
+    # Only remap if the transformers version meets the requirement.
+    if is_transformers_version(operation, required_version):
+        new_class = mapping["new_class"]
+        logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
+        return new_class
+
+    return None
+

 def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
     from .. import __version__
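A quick sketch of the helper's contract, assuming a transformers release newer than 4.57.0 is installed (the import path follows the `utils/__init__.py` change above):

```python
from diffusers.utils import _maybe_remap_transformers_class

# Remapped (with a logged warning) when transformers > 4.57.0 is installed;
# returns None on older versions, leaving the original class name in use.
print(_maybe_remap_transformers_class("CLIPFeatureExtractor"))  # "CLIPImageProcessor"

# Names without an entry in the remapping table are never touched.
print(_maybe_remap_transformers_class("CLIPTextModel"))  # None
```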
tests/others/test_attention_backends.py
Lines changed: 144 additions & 0 deletions

@@ -0,0 +1,144 @@
+"""
+This test suite exists for the maintainers currently. It's not run in our CI at the moment.
+
+Once attention backends become more mature, we can consider including this in our CI.
+
+To run this test suite:
+
+```bash
+export RUN_ATTENTION_BACKEND_TESTS=yes
+export DIFFUSERS_ENABLE_HUB_KERNELS=yes
+
+pytest tests/others/test_attention_backends.py
+```
+
+Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
+"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
+"""
+
+import os
+
+import pytest
+import torch
+
+
+pytestmark = pytest.mark.skipif(
+    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
+)
+from diffusers import FluxPipeline  # noqa: E402
+from diffusers.utils import is_torch_version  # noqa: E402
+
+
+# fmt: off
+FORWARD_CASES = [
+    ("flash_hub", None),
+    (
+        "_flash_3_hub",
+        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
+    ),
+    (
+        "native",
+        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
+    ),
+    (
+        "_native_cudnn",
+        torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
+    ),
+]
+
+COMPILE_CASES = [
+    ("flash_hub", None, True),
+    (
+        "_flash_3_hub",
+        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
+        True,
+    ),
+    (
+        "native",
+        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
+        True,
+    ),
+    (
+        "_native_cudnn",
+        torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
+        True,
+    ),
+]
+# fmt: on
+
+INFER_KW = {
+    "prompt": "dance doggo dance",
+    "height": 256,
+    "width": 256,
+    "num_inference_steps": 2,
+    "guidance_scale": 3.5,
+    "max_sequence_length": 128,
+    "output_type": "pt",
+}
+
+
+def _backend_is_probably_supported(pipe, name: str):
+    try:
+        pipe.transformer.set_attention_backend(name)
+        return pipe, True
+    except Exception:
+        return False
+
+
+def _check_if_slices_match(output, expected_slice):
+    img = output.images.detach().cpu()
+    generated_slice = img.flatten()
+    generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+    assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
+
+
+@pytest.fixture(scope="session")
+def device():
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for these tests.")
+    return torch.device("cuda:0")
+
+
+@pytest.fixture(scope="session")
+def pipe(device):
+    repo_id = "black-forest-labs/FLUX.1-dev"
+    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device)
+    pipe.set_progress_bar_config(disable=True)
+    return pipe
+
+
+@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
+def test_forward(pipe, backend_name, expected_slice):
+    out = _backend_is_probably_supported(pipe, backend_name)
+    if isinstance(out, bool):
+        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
+
+    modified_pipe = out[0]
+    out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
+    _check_if_slices_match(out, expected_slice)
+
+
+@pytest.mark.parametrize(
+    "backend_name,expected_slice,error_on_recompile",
+    COMPILE_CASES,
+    ids=[c[0] for c in COMPILE_CASES],
+)
+def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
+    if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
+        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
+
+    out = _backend_is_probably_supported(pipe, backend_name)
+    if isinstance(out, bool):
+        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
+
+    modified_pipe = out[0]
+    modified_pipe.transformer.compile(fullgraph=True)
+
+    torch.compiler.reset()
+    with (
+        torch._inductor.utils.fresh_inductor_cache(),
+        torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
+    ):
+        out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
+
+    _check_if_slices_match(out, expected_slice)
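For context, the API these tests exercise is `set_attention_backend` on the pipeline's transformer. A minimal manual-usage sketch mirroring the test's own constants (assumes a CUDA GPU and access to the FLUX.1-dev weights):

```python
import torch

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Swap the attention implementation; "_native_cudnn" is one of the
# backends parametrized in the tests above.
pipe.transformer.set_attention_backend("_native_cudnn")

image = pipe(
    "dance doggo dance",
    height=256,
    width=256,
    num_inference_steps=2,
    guidance_scale=3.5,
    max_sequence_length=128,
    generator=torch.manual_seed(0),
).images[0]
```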
