
Commit 0305b5a

Merge branch 'main' into enable-compilation

2 parents 20e30cb + 4b17fa2

File tree

10 files changed: +366 −3 lines

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -139,6 +139,7 @@
     "AutoGuidance",
     "ClassifierFreeGuidance",
     "ClassifierFreeZeroStarGuidance",
+    "FrequencyDecoupledGuidance",
     "PerturbedAttentionGuidance",
     "SkipLayerGuidance",
     "SmoothedEnergyGuidance",
@@ -804,6 +805,7 @@
     AutoGuidance,
     ClassifierFreeGuidance,
     ClassifierFreeZeroStarGuidance,
+    FrequencyDecoupledGuidance,
     PerturbedAttentionGuidance,
     SkipLayerGuidance,
     SmoothedEnergyGuidance,

src/diffusers/guiders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,7 @@
 from .auto_guidance import AutoGuidance
 from .classifier_free_guidance import ClassifierFreeGuidance
 from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
+from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
 from .perturbed_attention_guidance import PerturbedAttentionGuidance
 from .skip_layer_guidance import SkipLayerGuidance
 from .smoothed_energy_guidance import SmoothedEnergyGuidance
@@ -32,6 +33,7 @@
     AutoGuidance,
     ClassifierFreeGuidance,
     ClassifierFreeZeroStarGuidance,
+    FrequencyDecoupledGuidance,
     PerturbedAttentionGuidance,
     SkipLayerGuidance,
     SmoothedEnergyGuidance,

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 327 additions & 0 deletions
Large diffs are not rendered by default.
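As its name suggests, frequency-decoupled guidance applies separate guidance scales to low- and high-frequency components of the prediction, which is also why this PR wires up a kornia availability check. The sketch below is a conceptual illustration only, not the contents of this 327-line file; the band split, filter, and scales are assumptions:

    # Conceptual sketch, NOT the actual implementation in this file.
    import torch
    from kornia.filters import gaussian_blur2d

    def fdg_combine(pred_cond, pred_uncond, w_low=3.0, w_high=7.5):
        # split the guidance direction into low/high frequency bands
        diff = pred_cond - pred_uncond
        low = gaussian_blur2d(diff, kernel_size=(9, 9), sigma=(2.0, 2.0))
        high = diff - low
        # apply a separate guidance scale to each band
        return pred_uncond + w_low * low + w_high * high

    out = fdg_combine(torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64))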

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 1 deletion

@@ -384,7 +384,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_len = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@
     is_k_diffusion_available,
     is_k_diffusion_version,
     is_kernels_available,
+    is_kornia_available,
     is_librosa_available,
     is_matplotlib_available,
     is_nltk_available,

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions

@@ -62,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class FrequencyDecoupledGuidance(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class PerturbedAttentionGuidance(metaclass=DummyObject):
     _backends = ["torch"]
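For reference, DummyObject placeholders like the one added above keep "from diffusers import ..." working when torch is absent and defer the failure to first use. A sketch of that behavior (a torch-less environment is assumed; with torch installed the real class is exported and this prints nothing):

    from diffusers import FrequencyDecoupledGuidance
    from diffusers.utils import is_torch_available

    if not is_torch_available():
        try:
            FrequencyDecoupledGuidance()
        except ImportError as err:
            # requires_backends raises with install instructions for "torch"
            print(err)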

src/diffusers/utils/import_utils.py

Lines changed: 5 additions & 0 deletions

@@ -224,6 +224,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _sageattention_available, _sageattention_version = _is_package_available("sageattention")
 _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
 _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_kornia_available, _kornia_version = _is_package_available("kornia")


 def is_torch_available():
@@ -398,6 +399,10 @@ def is_flash_attn_3_available():
     return _flash_attn_3_available


+def is_kornia_available():
+    return _kornia_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
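A minimal sketch of how such a helper typically guards an optional dependency at import time (the guarded usage here is illustrative, not taken from this diff):

    from diffusers.utils import is_kornia_available

    if is_kornia_available():
        import kornia  # only imported when the package is actually installed
    else:
        kornia = None  # kornia-backed features should raise a clear error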

tests/quantization/bnb/test_4bit.py

Lines changed: 1 addition & 0 deletions

@@ -886,6 +886,7 @@ def quantization_config(self):
         components_to_quantize=["transformer", "text_encoder_2"],
     )

+    @require_bitsandbytes_version_greater("0.46.1")
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super().test_torch_compile()

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 4 additions & 0 deletions

@@ -847,6 +847,10 @@ def quantization_config(self):
         components_to_quantize=["transformer", "text_encoder_2"],
     )

+    @pytest.mark.xfail(
+        reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+        " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
+    )
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super()._test_torch_compile(torch_dtype=torch.float16)

tests/quantization/test_torch_compile_utils.py

Lines changed: 8 additions & 2 deletions

@@ -56,12 +56,18 @@ def _test_torch_compile(self, torch_dtype=torch.bfloat16):
         pipe.transformer.compile(fullgraph=True)

         # small resolutions to ensure speedy execution.
-        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

     def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
-        pipe.transformer.compile()
+        # regional compilation is better for offloading.
+        # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
+        if getattr(pipe.transformer, "_repeated_blocks"):
+            pipe.transformer.compile_repeated_blocks(fullgraph=True)
+        else:
+            pipe.transformer.compile()

         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
