Commit b8b165e

sd3 xformers quality
1 parent 2cb5ecf commit b8b165e

2 files changed, +9 -10 lines

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 8 deletions
@@ -1530,13 +1530,13 @@ def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
 
     def __call__(
-            self,
-            attn: Attention,
-            hidden_states: torch.FloatTensor,
-            encoder_hidden_states: torch.FloatTensor = None,
-            attention_mask: Optional[torch.FloatTensor] = None,
-            *args,
-            **kwargs,
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
         residual = hidden_states
 
@@ -1579,7 +1579,7 @@ def __call__(
         # Split the attention outputs.
         hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
-            hidden_states[:, residual.shape[1]:],
+            hidden_states[:, residual.shape[1] :],
         )
 
         # linear proj
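For context, the second hunk touches the step where the joint (image + text) attention output is split back into its two streams. A minimal sketch of that split, using hypothetical tensor sizes that are not part of the commit:

import torch

# Hypothetical sizes, for illustration only.
batch, image_tokens, text_tokens, dim = 2, 1024, 77, 1536

# `residual` holds the image-stream hidden states saved before attention.
residual = torch.randn(batch, image_tokens, dim)

# Joint attention runs over the concatenated image + text sequence, so its
# output spans both streams along the sequence dimension.
joint_output = torch.randn(batch, image_tokens + text_tokens, dim)

# Split the attention output back into the two streams, mirroring the slice above.
hidden_states, encoder_hidden_states = (
    joint_output[:, : residual.shape[1]],
    joint_output[:, residual.shape[1] :],
)

assert hidden_states.shape == (batch, image_tokens, dim)
assert encoder_hidden_states.shape == (batch, text_tokens, dim)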

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 1 addition & 2 deletions
@@ -89,6 +89,5 @@ def test_xformers_enable_works(self):
         model.enable_xformers_memory_efficient_attention()
 
         assert (
-            model.transformer_blocks[0].attn.processor.__class__.__name__
-            == "XFormersJointAttnProcessor"
+            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
         ), "xformers is not enabled"
