Skip to content

Commit 8f5b0e1

Browse files
authored
Merge branch 'main' into low-cpu-mem-usage-lora
2 parents cf4917c + ec9e526 commit 8f5b0e1

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,16 @@ def forward(
8383
hidden_states: torch.FloatTensor,
8484
temb: torch.FloatTensor,
8585
image_rotary_emb=None,
86+
joint_attention_kwargs=None,
8687
):
8788
residual = hidden_states
8889
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
8990
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
90-
91+
joint_attention_kwargs = joint_attention_kwargs or {}
9192
attn_output = self.attn(
9293
hidden_states=norm_hidden_states,
9394
image_rotary_emb=image_rotary_emb,
95+
**joint_attention_kwargs,
9496
)
9597

9698
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -161,18 +163,20 @@ def forward(
161163
encoder_hidden_states: torch.FloatTensor,
162164
temb: torch.FloatTensor,
163165
image_rotary_emb=None,
166+
joint_attention_kwargs=None,
164167
):
165168
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
166169

167170
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
168171
encoder_hidden_states, emb=temb
169172
)
170-
173+
joint_attention_kwargs = joint_attention_kwargs or {}
171174
# Attention.
172175
attn_output, context_attn_output = self.attn(
173176
hidden_states=norm_hidden_states,
174177
encoder_hidden_states=norm_encoder_hidden_states,
175178
image_rotary_emb=image_rotary_emb,
179+
**joint_attention_kwargs,
176180
)
177181

178182
# Process attention outputs for the `hidden_states`.
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
497501
encoder_hidden_states=encoder_hidden_states,
498502
temb=temb,
499503
image_rotary_emb=image_rotary_emb,
504+
joint_attention_kwargs=joint_attention_kwargs,
500505
)
501506

502507
# controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
533538
hidden_states=hidden_states,
534539
temb=temb,
535540
image_rotary_emb=image_rotary_emb,
541+
joint_attention_kwargs=joint_attention_kwargs,
536542
)
537543

538544
# controlnet residual

src/diffusers/pipelines/deepfloyd_if/pipeline_output.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@
99

1010
@dataclass
1111
class IFPipelineOutput(BaseOutput):
12-
"""
13-
Args:
12+
r"""
1413
Output class for Stable Diffusion pipelines.
15-
images (`List[PIL.Image.Image]` or `np.ndarray`)
14+
15+
Args:
16+
images (`List[PIL.Image.Image]` or `np.ndarray`):
1617
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
1718
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
18-
nsfw_detected (`List[bool]`)
19+
nsfw_detected (`List[bool]`):
1920
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
2021
(nsfw) content or a watermark. `None` if safety checking could not be performed.
21-
watermark_detected (`List[bool]`)
22+
watermark_detected (`List[bool]`):
2223
List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
2324
checking could not be performed.
2425
"""

0 commit comments

Comments
 (0)