Skip to content

Commit 59c8552

Browse files
authored
Merge branch 'main' into hunyuan-video
2 parents d0c61e0 + bdbaea8 commit 59c8552

12 files changed

+389
-363
lines changed

examples/community/README_community_scripts.md

Lines changed: 149 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline
241241
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
242242
from diffusers.configuration_utils import register_to_config
243243
import torch
244-
from typing import Any, Dict, Optional
244+
from typing import Any, Dict, Tuple, Union
245+
246+
247+
class SDPromptSchedulingCallback(PipelineCallback):
    """Replace the pipeline's prompt embeddings at one specific denoising step.

    The embeddings are supplied already encoded, so nothing is re-encoded
    while the pipeline is running.

    Args:
        encoded_prompt: Either the prompt embeddings alone, or a
            `(prompt_embeds, negative_prompt_embeds)` tuple. The tuple form
            is required when the pipeline runs with classifier-free guidance.
            (Presumably the value returned by `pipeline.encode_prompt` --
            confirm against the pipeline in use.)
        cutoff_step_ratio: Fraction of `pipeline.num_timesteps` at which the
            swap happens. Only used when `cutoff_step_index` is None.
        cutoff_step_index: Step index at which the swap happens. Takes
            precedence over `cutoff_step_ratio` when it is not None.
    """

    @register_to_config
    def __init__(
        self,
        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        cutoff_step_ratio=None,
        cutoff_step_index=None,
    ):
        super().__init__(
            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
        )

    # Name of the callback tensor this callback overwrites.
    tensor_inputs = ["prompt_embeds"]

    def callback_fn(
        self, pipeline, step_index, timestep, callback_kwargs
    ) -> Dict[str, Any]:
        """Swap `prompt_embeds` in `callback_kwargs` when the cutoff step is reached.

        Returns `callback_kwargs`, modified only when `step_index` equals the
        cutoff step.

        Raises:
            ValueError: If the cutoff step is reached with classifier-free
                guidance enabled but no negative prompt embeddings were given.
        """
        cutoff_step_ratio = self.config.cutoff_step_ratio
        cutoff_step_index = self.config.cutoff_step_index

        encoded_prompt = self.config.encoded_prompt
        if isinstance(encoded_prompt, tuple):
            prompt_embeds, negative_prompt_embeds = encoded_prompt
        else:
            # Bind the negative embeddings explicitly so the guidance branch
            # below can detect their absence instead of hitting an unbound name.
            prompt_embeds, negative_prompt_embeds = encoded_prompt, None

        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
        cutoff_step = (
            cutoff_step_index
            if cutoff_step_index is not None
            else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        if step_index == cutoff_step:
            if pipeline.do_classifier_free_guidance:
                if negative_prompt_embeds is None:
                    raise ValueError(
                        "`encoded_prompt` must be a (prompt_embeds, negative_prompt_embeds) "
                        "tuple when classifier-free guidance is enabled."
                    )
                # Negative embeddings go first, then the positive ones.
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
        return callback_kwargs
245283

246284

247285
pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
@@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
253291
pipeline.safety_checker = None
254292
pipeline.requires_safety_checker = False
255293

294+
# Build one scheduling callback per step index 1..19, each carrying its own
# pre-encoded prompt / negative prompt pair.
callback = MultiPipelineCallbacks(
    [
        SDPromptSchedulingCallback(
            encoded_prompt=pipeline.encode_prompt(
                prompt=f"prompt {index}",
                negative_prompt=f"negative prompt {index}",
                device=pipeline._execution_device,
                num_images_per_prompt=1,
                # pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has been run
                do_classifier_free_guidance=True,
            ),
            cutoff_step_index=index,
        )
        for index in range(1, 20)
    ]
)

image = pipeline(
    prompt="prompt",
    negative_prompt="negative prompt",
    callback_on_step_end=callback,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
torch.cuda.empty_cache()
image.save('image.png')
318+
```
256319

257-
class SDPromptScheduleCallback(PipelineCallback):
320+
```python
321+
from diffusers import StableDiffusionXLPipeline
322+
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
323+
from diffusers.configuration_utils import register_to_config
324+
import torch
325+
from typing import Any, Dict, Tuple, Union
326+
327+
328+
class SDXLPromptSchedulingCallback(PipelineCallback):
258329
@register_to_config
259330
def __init__(
260331
self,
261-
prompt: str,
262-
negative_prompt: Optional[str] = None,
263-
num_images_per_prompt: int = 1,
264-
cutoff_step_ratio=1.0,
332+
encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
333+
add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
334+
add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
335+
cutoff_step_ratio=None,
265336
cutoff_step_index=None,
266337
):
267338
super().__init__(
268339
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
269340
)
270341

271-
tensor_inputs = ["prompt_embeds"]
342+
tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
272343

273344
def callback_fn(
274345
self, pipeline, step_index, timestep, callback_kwargs
275346
) -> Dict[str, Any]:
276347
cutoff_step_ratio = self.config.cutoff_step_ratio
277348
cutoff_step_index = self.config.cutoff_step_index
349+
if isinstance(self.config.encoded_prompt, tuple):
350+
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
351+
else:
352+
prompt_embeds = self.config.encoded_prompt
353+
if isinstance(self.config.add_text_embeds, tuple):
354+
add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
355+
else:
356+
add_text_embeds = self.config.add_text_embeds
357+
if isinstance(self.config.add_time_ids, tuple):
358+
add_time_ids, negative_add_time_ids = self.config.add_time_ids
359+
else:
360+
add_time_ids = self.config.add_time_ids
278361

279362
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
280363
cutoff_step = (
@@ -284,34 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback):
284367
)
285368

286369
if step_index == cutoff_step:
287-
prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
288-
prompt=self.config.prompt,
289-
negative_prompt=self.config.negative_prompt,
290-
device=pipeline._execution_device,
291-
num_images_per_prompt=self.config.num_images_per_prompt,
292-
do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
293-
)
294370
if pipeline.do_classifier_free_guidance:
295371
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
372+
add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
373+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
296374
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
375+
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
376+
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
297377
return callback_kwargs
298378

299-
callback = MultiPipelineCallbacks(
300-
[
301-
SDPromptScheduleCallback(
302-
prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
303-
negative_prompt="Deformed, ugly, bad anatomy",
304-
cutoff_step_ratio=0.25,
379+
380+
pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
381+
"stabilityai/stable-diffusion-xl-base-1.0",
382+
torch_dtype=torch.float16,
383+
variant="fp16",
384+
use_safetensors=True,
385+
).to("cuda")
386+
387+
callbacks = []
388+
for index in range(1, 20):
389+
(
390+
prompt_embeds,
391+
negative_prompt_embeds,
392+
pooled_prompt_embeds,
393+
negative_pooled_prompt_embeds,
394+
) = pipeline.encode_prompt(
395+
prompt=f"prompt {index}",
396+
negative_prompt=f"prompt {index}",
397+
device=pipeline._execution_device,
398+
num_images_per_prompt=1,
399+
# pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has been run
400+
do_classifier_free_guidance=True,
401+
)
402+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
403+
add_time_ids = pipeline._get_add_time_ids(
404+
(1024, 1024),
405+
(0, 0),
406+
(1024, 1024),
407+
dtype=prompt_embeds.dtype,
408+
text_encoder_projection_dim=text_encoder_projection_dim,
409+
)
410+
negative_add_time_ids = pipeline._get_add_time_ids(
411+
(1024, 1024),
412+
(0, 0),
413+
(1024, 1024),
414+
dtype=prompt_embeds.dtype,
415+
text_encoder_projection_dim=text_encoder_projection_dim,
416+
)
417+
callbacks.append(
418+
SDXLPromptSchedulingCallback(
419+
encoded_prompt=(prompt_embeds, negative_prompt_embeds),
420+
add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
421+
add_time_ids=(add_time_ids, negative_add_time_ids),
422+
cutoff_step_index=index,
305423
)
306-
]
307-
)
424+
)
425+
426+
427+
callback = MultiPipelineCallbacks(callbacks)
308428

309429
image = pipeline(
310-
prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
311-
negative_prompt="Deformed, ugly, bad anatomy",
430+
prompt="prompt",
431+
negative_prompt="negative prompt",
312432
callback_on_step_end=callback,
313-
callback_on_step_end_tensor_inputs=["prompt_embeds"],
433+
callback_on_step_end_tensor_inputs=[
434+
"prompt_embeds",
435+
"add_text_embeds",
436+
"add_time_ids",
437+
],
314438
).images[0]
315-
torch.cuda.empty_cache()
316-
image.save('image.png')
317439
```

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
SparseControlNetModel,
1616
SparseControlNetOutput,
1717
)
18-
from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
18+
from .controlnet_union import ControlNetUnionModel
1919
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
2020
from .multicontrolnet import MultiControlNetModel
2121

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 8 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from dataclasses import dataclass
1514
from typing import Any, Dict, List, Optional, Tuple, Union
1615

1716
import torch
1817
from torch import nn
1918

2019
from ...configuration_utils import ConfigMixin, register_to_config
21-
from ...image_processor import PipelineImageInput
2220
from ...loaders.single_file_model import FromOriginalModelMixin
2321
from ...utils import logging
2422
from ..attention_processor import (
@@ -40,76 +38,6 @@
4038
from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
4139

4240

43-
@dataclass
44-
class ControlNetUnionInput:
45-
"""
46-
The image input of [`ControlNetUnionModel`]:
47-
48-
- 0: openpose
49-
- 1: depth
50-
- 2: hed/pidi/scribble/ted
51-
- 3: canny/lineart/anime_lineart/mlsd
52-
- 4: normal
53-
- 5: segment
54-
"""
55-
56-
openpose: Optional[PipelineImageInput] = None
57-
depth: Optional[PipelineImageInput] = None
58-
hed: Optional[PipelineImageInput] = None
59-
canny: Optional[PipelineImageInput] = None
60-
normal: Optional[PipelineImageInput] = None
61-
segment: Optional[PipelineImageInput] = None
62-
63-
def __len__(self) -> int:
64-
return len(vars(self))
65-
66-
def __iter__(self):
67-
return iter(vars(self))
68-
69-
def __getitem__(self, key):
70-
return getattr(self, key)
71-
72-
def __setitem__(self, key, value):
73-
setattr(self, key, value)
74-
75-
76-
@dataclass
77-
class ControlNetUnionInputProMax:
78-
"""
79-
The image input of [`ControlNetUnionModel`]:
80-
81-
- 0: openpose
82-
- 1: depth
83-
- 2: hed/pidi/scribble/ted
84-
- 3: canny/lineart/anime_lineart/mlsd
85-
- 4: normal
86-
- 5: segment
87-
- 6: tile
88-
- 7: repaint
89-
"""
90-
91-
openpose: Optional[PipelineImageInput] = None
92-
depth: Optional[PipelineImageInput] = None
93-
hed: Optional[PipelineImageInput] = None
94-
canny: Optional[PipelineImageInput] = None
95-
normal: Optional[PipelineImageInput] = None
96-
segment: Optional[PipelineImageInput] = None
97-
tile: Optional[PipelineImageInput] = None
98-
repaint: Optional[PipelineImageInput] = None
99-
100-
def __len__(self) -> int:
101-
return len(vars(self))
102-
103-
def __iter__(self):
104-
return iter(vars(self))
105-
106-
def __getitem__(self, key):
107-
return getattr(self, key)
108-
109-
def __setitem__(self, key, value):
110-
setattr(self, key, value)
111-
112-
11341
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
11442

11543

@@ -680,8 +608,9 @@ def forward(
680608
sample: torch.Tensor,
681609
timestep: Union[torch.Tensor, float, int],
682610
encoder_hidden_states: torch.Tensor,
683-
controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
611+
controlnet_cond: List[torch.Tensor],
684612
control_type: torch.Tensor,
613+
control_type_idx: List[int],
685614
conditioning_scale: float = 1.0,
686615
class_labels: Optional[torch.Tensor] = None,
687616
timestep_cond: Optional[torch.Tensor] = None,
@@ -701,11 +630,13 @@ def forward(
701630
The number of timesteps to denoise an input.
702631
encoder_hidden_states (`torch.Tensor`):
703632
The encoder hidden states.
704-
controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
633+
controlnet_cond (`List[torch.Tensor]`):
705634
The conditional input tensors.
706635
control_type (`torch.Tensor`):
707636
A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
708637
type is used.
638+
control_type_idx (`List[int]`):
639+
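The control type index of each tensor in `controlnet_cond`, in the same order; each index selects the task embedding added to that condition.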
709640
conditioning_scale (`float`, defaults to `1.0`):
710641
The scale factor for ControlNet outputs.
711642
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
@@ -733,20 +664,6 @@ def forward(
733664
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
734665
returned where the first element is the sample tensor.
735666
"""
736-
if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
737-
raise ValueError(
738-
"Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
739-
)
740-
if len(controlnet_cond) != self.config.num_control_type:
741-
if isinstance(controlnet_cond, ControlNetUnionInput):
742-
raise ValueError(
743-
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
744-
)
745-
elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
746-
raise ValueError(
747-
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
748-
)
749-
750667
# check channel order
751668
channel_order = self.config.controlnet_conditioning_channel_order
752669

@@ -830,12 +747,10 @@ def forward(
830747
inputs = []
831748
condition_list = []
832749

833-
for idx, image_type in enumerate(controlnet_cond):
834-
if controlnet_cond[image_type] is None:
835-
continue
836-
condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
750+
for cond, control_idx in zip(controlnet_cond, control_type_idx):
751+
condition = self.controlnet_cond_embedding(cond)
837752
feat_seq = torch.mean(condition, dim=(2, 3))
838-
feat_seq = feat_seq + self.task_embedding[idx]
753+
feat_seq = feat_seq + self.task_embedding[control_idx]
839754
inputs.append(feat_seq.unsqueeze(1))
840755
condition_list.append(condition)
841756

0 commit comments

Comments
 (0)