Commit dbe4725

adding warning message
1 parent ff332e6 commit dbe4725

File tree (2 files changed, +12 −7 lines):

  src/diffusers/models/attention_processor.py
  src/diffusers/utils/import_utils.py


src/diffusers/models/attention_processor.py

Lines changed: 10 additions & 7 deletions
@@ -294,13 +294,15 @@ def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_s
             partition_spec (`Tuple[]`, *optional*):
                 Specify the partition specification if using SPMD. Otherwise None.
         """
-        if (
-            use_xla_flash_attention
-            and is_torch_xla_available
-            and is_torch_xla_version('>', '2.2')
-            and (not is_spmd() or is_torch_xla_version('>', '2.3'))
-        ):
-            processor = XLAFlashAttnProcessor2_0(partition_spec)
+        if use_xla_flash_attention:
+            if not is_torch_xla_available:
+                raise "torch_xla is not available"
+            elif is_torch_xla_version("<", "2.3"):
+                raise "flash attention pallas kernel is supported from torch_xla version 2.3"
+            elif is_spmd() and is_torch_xla_version("<", "2.4"):
+                raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
+            else:
+                processor = XLAFlashAttnProcessor2_0(partition_spec)
         else:
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
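
Note on this hunk: `raise` is applied to bare string literals. Python 3 removed string exceptions, so hitting any of these branches raises `TypeError: exceptions must derive from BaseException` instead of surfacing the intended message. `is_torch_xla_available` is also referenced without being called; if it is the checker function from `diffusers.utils.import_utils`, the bare function object is always truthy and `not is_torch_xla_available` can never be True. A minimal corrected sketch (the exception classes below are assumptions, not what the commit uses):

    if use_xla_flash_attention:
        if not is_torch_xla_available():  # call the checker; the bare name is always truthy
            raise ImportError("torch_xla is not available")
        elif is_torch_xla_version("<", "2.3"):
            raise RuntimeError("flash attention pallas kernel is supported from torch_xla version 2.3")
        elif is_spmd() and is_torch_xla_version("<", "2.4"):
            raise RuntimeError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
        else:
            processor = XLAFlashAttnProcessor2_0(partition_spec)
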
@@ -2871,6 +2873,7 @@ def __call__(
             partition_spec = self.partition_spec if is_spmd() else None
             hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
         else:
+            logger.warning(f"Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.")
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
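
Note on this hunk: the new `logger.warning` call uses an f-string with no placeholders, so the `f` prefix is redundant; a plain string literal does the same thing. For context, here is a sketch of the dispatch this warning sits in, with the sequence-length guard written as an assumption (the hunk shows only the two branches, not the test itself):

    # Hypothetical guard: the pallas flash-attention path is only taken for long sequences.
    if all(t.shape[2] >= 4096 for t in (query, key, value)):
        partition_spec = self.partition_spec if is_spmd() else None
        hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
    else:
        logger.warning("Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.")
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
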

src/diffusers/utils/import_utils.py

Lines changed: 2 additions & 0 deletions
@@ -710,6 +710,8 @@ def is_torch_xla_version(operation: str, version: str):
         version (`str`):
             A string version of torch_xla
     """
+    if not is_torch_xla_available:
+        return False
     return compare_versions(parse(_torch_xla_version), operation, version)
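
Note on this hunk: the added guard lets `is_torch_xla_version` degrade to `False` instead of letting `parse(_torch_xla_version)` fail when torch_xla is not installed and `_torch_xla_version` is unset. As in the first file, though, `is_torch_xla_available` is tested without parentheses, which is always truthy if it is a function, so the guard would never trigger. A sketch of the guarded function with that assumed fix applied:

    def is_torch_xla_version(operation: str, version: str):
        """Compare the installed torch_xla version against `version` using `operation`."""
        if not is_torch_xla_available():  # assumed call; the bare function object is always truthy
            return False
        return compare_versions(parse(_torch_xla_version), operation, version)
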
