
Commit 05b706c

PAG variant for AnimateDiff (#8789)
* add animatediff pag pipeline
* remove unnecessary print
* make fix-copies
* fix ip-adapter bug
* update docs
* add fast tests and fix bugs
* update
* update
* address review comments
* update ip adapter single test expected slice
* implement test_from_pipe_consistent_config; fix expected slice values
* LoraLoaderMixin->StableDiffusionLoraLoaderMixin; add latest freeinit test
1 parent ea1b4ea commit 05b706c

File tree: 8 files changed (+1395, -14 lines)

docs/source/en/api/pipelines/pag.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -20,6 +20,11 @@ The abstract from the paper is:
 
 *Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
 
+## AnimateDiffPAGPipeline
+[[autodoc]] AnimateDiffPAGPipeline
+    - all
+    - __call__
+
 ## StableDiffusionPAGPipeline
 [[autodoc]] StableDiffusionPAGPipeline
     - all
```
src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -233,6 +233,7 @@
         "AmusedInpaintPipeline",
         "AmusedPipeline",
         "AnimateDiffControlNetPipeline",
+        "AnimateDiffPAGPipeline",
         "AnimateDiffPipeline",
         "AnimateDiffSDXLPipeline",
         "AnimateDiffSparseControlNetPipeline",
@@ -654,6 +655,7 @@
         AmusedInpaintPipeline,
         AmusedPipeline,
         AnimateDiffControlNetPipeline,
+        AnimateDiffPAGPipeline,
         AnimateDiffPipeline,
         AnimateDiffSDXLPipeline,
         AnimateDiffSparseControlNetPipeline,
```

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -143,6 +143,7 @@
     )
     _import_structure["pag"].extend(
         [
+            "AnimateDiffPAGPipeline",
             "StableDiffusionPAGPipeline",
             "StableDiffusionControlNetPAGPipeline",
             "StableDiffusionXLPAGPipeline",
@@ -527,6 +528,7 @@
         )
         from .musicldm import MusicLDMPipeline
         from .pag import (
+            AnimateDiffPAGPipeline,
             StableDiffusionControlNetPAGPipeline,
             StableDiffusionPAGPipeline,
             StableDiffusionXLControlNetPAGPipeline,
```

src/diffusers/pipelines/pag/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@
     _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
     _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
     _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
+    _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
     _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
     _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
     _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
@@ -40,6 +41,7 @@
     from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
     from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
     from .pipeline_pag_sd import StableDiffusionPAGPipeline
+    from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
     from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
     from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
     from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
```

src/diffusers/pipelines/pag/pag_utils.py

Lines changed: 31 additions & 14 deletions
```diff
@@ -33,7 +33,7 @@ def _check_input_pag_applied_layer(layer):
         Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats:
         "{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type`
         can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be
-        in the format of "attentions_{j}".
+        in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}"
         """
 
         layer_splits = layer.split(".")
@@ -52,8 +52,11 @@ def _check_input_pag_applied_layer(layer):
                 raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'")
 
         if len(layer_splits) == 3:
-            if not layer_splits[2].startswith("attentions_"):
-                raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'")
+            layer_2 = layer_splits[2]
+            if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"):
+                raise ValueError(
+                    f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'"
+                )
 
     def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
         r"""
```
```diff
@@ -72,33 +75,46 @@ def is_self_attn(module_name):
 
         def get_block_type(module_name):
             r"""
-            Get the block type from the module name. can be "down", "mid", "up".
+            Get the block type from the module name. Can be "down", "mid", "up".
             """
             # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down"
+            # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down"
             return module_name.split(".")[0].split("_")[0]
 
         def get_block_index(module_name):
             r"""
-            Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g.
+            Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g.
             mid_block) and index is ommited from the name, it will be "block_0".
             """
             # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1"
             # mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0"
-            if "attentions" in module_name.split(".")[1]:
+            module_name_splits = module_name.split(".")
+            block_index = module_name_splits[1]
+            if "attentions" in block_index or "motion_modules" in block_index:
                 return "block_0"
             else:
-                return f"block_{module_name.split('.')[1]}"
+                return f"block_{block_index}"
 
         def get_attn_index(module_name):
             r"""
-            Get the attention index from the module name. can be "attentions_0", "attentions_1", ...
+            Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0",
+            "motion_modules_1", ...
             """
             # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
             # mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
-            if "attentions" in module_name.split(".")[2]:
-                return f"attentions_{module_name.split('.')[3]}"
-            elif "attentions" in module_name.split(".")[1]:
-                return f"attentions_{module_name.split('.')[2]}"
+            # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
+            # mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
+            module_name_split = module_name.split(".")
+            mid_name = module_name_split[1]
+            down_name = module_name_split[2]
+            if "attentions" in down_name:
+                return f"attentions_{module_name_split[3]}"
+            if "attentions" in mid_name:
+                return f"attentions_{module_name_split[2]}"
+            if "motion_modules" in down_name:
+                return f"motion_modules_{module_name_split[3]}"
+            if "motion_modules" in mid_name:
+                return f"motion_modules_{module_name_split[2]}"
 
         for pag_layer_input in pag_applied_layers:
             # for each PAG layer input, we find corresponding self-attention layers in the unet model
```
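To illustrate the mapping these helpers encode, here is a standalone sketch of `get_attn_index` (reproducing the logic above outside the mixin; not part of the commit), checked against the module names from the diff's own comments:

```python
# Standalone re-implementation of get_attn_index, for illustration only.
def get_attn_index(module_name):
    parts = module_name.split(".")
    mid_name, down_name = parts[1], parts[2]
    if "attentions" in down_name:           # down/up blocks: index follows "attentions"
        return f"attentions_{parts[3]}"
    if "attentions" in mid_name:            # mid block: one position earlier in the name
        return f"attentions_{parts[2]}"
    if "motion_modules" in down_name:       # same pattern for the new motion modules
        return f"motion_modules_{parts[3]}"
    if "motion_modules" in mid_name:
        return f"motion_modules_{parts[2]}"

assert get_attn_index("down_blocks.1.attentions.0.transformer_blocks.0.attn1") == "attentions_0"
assert get_attn_index("mid_block.attentions.0.transformer_blocks.0.attn1") == "attentions_0"
assert get_attn_index("down_blocks.1.motion_modules.0.transformer_blocks.0.attn1") == "motion_modules_0"
assert get_attn_index("mid_block.motion_modules.0.transformer_blocks.0.attn1") == "motion_modules_0"
```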
```diff
@@ -114,7 +130,7 @@ def get_attn_index(module_name):
                     target_modules.append(module)
 
             elif len(pag_layer_input_splits) == 2:
-                # when the layer inpput contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
+                # when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
                 block_type = pag_layer_input_splits[0]
                 block_index = pag_layer_input_splits[1]
                 for name, module in self.unet.named_modules():
@@ -126,7 +142,8 @@ def get_attn_index(module_name):
                     target_modules.append(module)
 
             elif len(pag_layer_input_splits) == 3:
-                # when the layer input contains block_type, block_index and attention_index. e.g. "down.blocks_1.attentions_1"
+                # when the layer input contains block_type, block_index and attention_index.
+                # e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1"
                 block_type = pag_layer_input_splits[0]
                 block_index = pag_layer_input_splits[1]
                 attn_index = pag_layer_input_splits[2]
```
