Skip to content

Commit 03089f5

Browse files
committed
make style
1 parent 0225ed7 commit 03089f5

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ def __init__(
284284
)
285285
self.set_processor(processor)
286286

287-
def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None) -> None:
287+
def set_use_xla_flash_attention(
288+
self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
289+
) -> None:
288290
r"""
289291
Set whether to use xla flash attention from `torch_xla` or not.
290292
@@ -296,7 +298,7 @@ def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_s
296298
"""
297299
if use_xla_flash_attention:
298300
if not is_torch_xla_available:
299-
raise "torch_xla is not available"
301+
raise "torch_xla is not available"
300302
elif is_torch_xla_version("<", "2.3"):
301303
raise "flash attention pallas kernel is supported from torch_xla version 2.3"
302304
elif is_spmd() and is_torch_xla_version("<", "2.4"):
@@ -2794,12 +2796,14 @@ class XLAFlashAttnProcessor2_0:
27942796

27952797
def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
27962798
if not hasattr(F, "scaled_dot_product_attention"):
2797-
raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2799+
raise ImportError(
2800+
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2801+
)
27982802
if is_torch_xla_version("<", "2.3"):
27992803
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
28002804
if is_spmd() and is_torch_xla_version("<", "2.4"):
28012805
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
2802-
self.partition_spec=partition_spec
2806+
self.partition_spec = partition_spec
28032807

28042808
def __call__(
28052809
self,
@@ -2875,7 +2879,9 @@ def __call__(
28752879
partition_spec = self.partition_spec if is_spmd() else None
28762880
hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
28772881
else:
2878-
logger.warning("Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.")
2882+
logger.warning(
2883+
"Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
2884+
)
28792885
hidden_states = F.scaled_dot_product_attention(
28802886
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
28812887
)

0 commit comments

Comments
 (0)