|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import inspect |
16 | | -from typing import Callable, Dict, List, Optional, Union |
| 16 | +from typing import Any, Callable, Dict, List, Optional, Union |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | from transformers import ( |
@@ -828,6 +828,10 @@ def clip_skip(self): |
828 | 828 | def do_classifier_free_guidance(self): |
829 | 829 | return self._guidance_scale > 1 |
830 | 830 |
|
| 831 | + @property |
| 832 | + def joint_attention_kwargs(self): |
| 833 | + return self._joint_attention_kwargs |
| 834 | + |
831 | 835 | @property |
832 | 836 | def num_timesteps(self): |
833 | 837 | return self._num_timesteps |
@@ -945,6 +949,7 @@ def __call__( |
945 | 949 | ip_adapter_image_embeds: Optional[torch.Tensor] = None, |
946 | 950 | output_type: Optional[str] = "pil", |
947 | 951 | return_dict: bool = True, |
| 952 | + joint_attention_kwargs: Optional[Dict[str, Any]] = None, |
948 | 953 | clip_skip: Optional[int] = None, |
949 | 954 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
950 | 955 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
@@ -1055,6 +1060,10 @@ def __call__( |
1055 | 1060 | return_dict (`bool`, *optional*, defaults to `True`): |
1056 | 1061 | Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of |
1057 | 1062 | a plain tuple. |
| 1063 | + joint_attention_kwargs (`dict`, *optional*): |
| 1064 | + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under |
| 1065 | + `self.processor` in |
| 1066 | + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
1058 | 1067 | callback_on_step_end (`Callable`, *optional*): |
1059 | 1068 | A function that is called at the end of each denoising step during inference. The function is called |
1060 | 1069 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, |
@@ -1102,6 +1111,7 @@ def __call__( |
1102 | 1111 |
|
1103 | 1112 | self._guidance_scale = guidance_scale |
1104 | 1113 | self._clip_skip = clip_skip |
| 1114 | + self._joint_attention_kwargs = joint_attention_kwargs |
1105 | 1115 | self._interrupt = False |
1106 | 1116 |
|
1107 | 1117 | # 2. Define call parameters |
@@ -1292,6 +1302,7 @@ def __call__( |
1292 | 1302 | timestep=timestep, |
1293 | 1303 | encoder_hidden_states=prompt_embeds, |
1294 | 1304 | pooled_projections=pooled_prompt_embeds, |
| 1305 | + joint_attention_kwargs=self.joint_attention_kwargs, |
1295 | 1306 | return_dict=False, |
1296 | 1307 | )[0] |
1297 | 1308 |
|
|
0 commit comments