Skip to content

Commit 95cee7d

Browse files
committed
Added joint_attention_kwargs property
1 parent 64f6991 commit 95cee7d

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import inspect
16-
from typing import Callable, Dict, List, Optional, Union
16+
from typing import Any, Callable, Dict, List, Optional, Union
1717

1818
import torch
1919
from transformers import (
@@ -828,6 +828,10 @@ def clip_skip(self):
828828
def do_classifier_free_guidance(self):
829829
return self._guidance_scale > 1
830830

831+
@property
832+
def joint_attention_kwargs(self):
833+
return self._joint_attention_kwargs
834+
831835
@property
832836
def num_timesteps(self):
833837
return self._num_timesteps
@@ -945,6 +949,7 @@ def __call__(
945949
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
946950
output_type: Optional[str] = "pil",
947951
return_dict: bool = True,
952+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
948953
clip_skip: Optional[int] = None,
949954
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
950955
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -1055,6 +1060,10 @@ def __call__(
10551060
return_dict (`bool`, *optional*, defaults to `True`):
10561061
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
10571062
a plain tuple.
1063+
joint_attention_kwargs (`dict`, *optional*):
1064+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1065+
`self.processor` in
1066+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
10581067
callback_on_step_end (`Callable`, *optional*):
10591068
A function that calls at the end of each denoising steps during the inference. The function is called
10601069
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -1102,6 +1111,7 @@ def __call__(
11021111

11031112
self._guidance_scale = guidance_scale
11041113
self._clip_skip = clip_skip
1114+
self._joint_attention_kwargs = joint_attention_kwargs
11051115
self._interrupt = False
11061116

11071117
# 2. Define call parameters
@@ -1292,6 +1302,7 @@ def __call__(
12921302
timestep=timestep,
12931303
encoder_hidden_states=prompt_embeds,
12941304
pooled_projections=pooled_prompt_embeds,
1305+
joint_attention_kwargs=self.joint_attention_kwargs,
12951306
return_dict=False,
12961307
)[0]
12971308

0 commit comments

Comments
 (0)