
Commit 2c00cbd

use version check for torch_xla
1 parent 6c74c79 commit 2c00cbd

File tree

3 files changed: +34 -14 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 20 additions & 14 deletions
@@ -21,7 +21,7 @@
 
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, is_torch_xla_available, logging
-from ..utils.import_utils import is_torch_npu_available, is_xformers_available
+from ..utils.import_utils import is_torch_npu_available, is_xformers_available, is_torch_xla_version
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
 
@@ -37,8 +37,10 @@
     xformers = None
 
 if is_torch_xla_available():
-    from torch_xla.experimental.custom_kernel import flash_attention
-
+    # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
+    if is_torch_xla_version(">", "2.2"):
+        from torch_xla.runtime import is_spmd
+        from torch_xla.experimental.custom_kernel import flash_attention
     XLA_AVAILABLE = True
 else:
     XLA_AVAILABLE = False
@@ -276,16 +278,21 @@ def __init__(
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
-        # If torch_xla is available, we use pallas flash attention kernel to improve the performance.
+        # If torch_xla is available with the correct version, we use pallas flash attention kernel to improve
+        # the performance.
         if processor is None:
             if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
-                if is_torch_xla_available:
+                if (
+                    is_torch_xla_available
+                    and is_torch_xla_version('>', '2.2')
+                    and (not is_spmd() or is_torch_xla_version('>', '2.3'))
+                ):
                     processor = XLAFlashAttnProcessor2_0()
                 else:
                     processor = AttnProcessor2_0()
             else:
                 processor = AttnProcessor()
-
+
         self.set_processor(processor)
 
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
@@ -2771,8 +2778,10 @@ class XLAFlashAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-        if not is_torch_xla_available:
-            raise ImportError("XLAFlashAttnProcessor2_0 required torch_xla package.")
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
 
     def __call__(
         self,
@@ -2784,10 +2793,6 @@ def __call__(
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2836,7 +2841,7 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+        if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
            if attention_mask is not None:
                attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
                # Convert mask to float and replace 0s with -inf and 1s with 0
@@ -2849,7 +2854,8 @@ def __call__(
                # Apply attention mask to key
                key = key + attention_mask
            query /= math.sqrt(query.shape[3])
-           hidden_states = flash_attention(query, key, value, causal=False, partition_spec=("data", None, None, None))
+           partition_spec = ("data", None, None, None) if is_spmd() else None
+           hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
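
Taken together, these changes gate the pallas kernel on both the torch_xla version and SPMD mode. A minimal standalone sketch of that selection logic follows; it reuses the real is_torch_xla_available/is_torch_xla_version helpers and the torch_xla.runtime.is_spmd and torch_xla.experimental.custom_kernel.flash_attention entry points, while the select_xla_flash_attention wrapper itself is a hypothetical illustration, not part of diffusers.

# Hedged sketch mirroring the version/SPMD gating introduced in this commit.
# `select_xla_flash_attention` is a hypothetical helper, for illustration only.
from diffusers.utils import is_torch_xla_available, is_torch_xla_version


def select_xla_flash_attention():
    """Return a callable wrapping the pallas flash_attention kernel, or None if unsupported."""
    if not is_torch_xla_available():
        return None
    # The pallas flash attention kernel ships with torch_xla 2.3 and later.
    if not is_torch_xla_version(">", "2.2"):
        return None
    from torch_xla.runtime import is_spmd
    from torch_xla.experimental.custom_kernel import flash_attention

    # Under SPMD, sharded flash attention needs torch_xla 2.4 or later.
    if is_spmd() and not is_torch_xla_version(">", "2.3"):
        return None
    partition_spec = ("data", None, None, None) if is_spmd() else None

    def _attn(query, key, value):
        # Matches the call made in XLAFlashAttnProcessor2_0 above.
        return flash_attention(query, key, value, causal=False, partition_spec=partition_spec)

    return _attn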

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@
     is_torch_npu_available,
     is_torch_version,
     is_torch_xla_available,
+    is_torch_xla_version,
     is_torchsde_available,
     is_torchvision_available,
     is_transformers_available,
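
With this export in place, the helper can be imported straight from diffusers.utils alongside the availability check; a small usage sketch (the version string is only illustrative):

from diffusers.utils import is_torch_xla_available, is_torch_xla_version

# Guard an optional pallas code path on both availability and version.
if is_torch_xla_available() and is_torch_xla_version(">=", "2.3"):
    print("torch_xla with pallas flash attention support detected")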

src/diffusers/utils/import_utils.py

Lines changed: 13 additions & 0 deletions
@@ -700,6 +700,19 @@ def is_torch_version(operation: str, version: str):
     return compare_versions(parse(_torch_version), operation, version)
 
 
+def is_torch_xla_version(operation: str, version: str):
+    """
+    Compares the current torch_xla version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A string version of torch_xla
+    """
+    return compare_versions(parse(_torch_xla_version), operation, version)
+
+
 def is_transformers_version(operation: str, version: str):
     """
     Compares the current Transformers version to a given reference with an operation.
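
compare_versions and _torch_xla_version are diffusers internals; for readers outside the codebase, here is a self-contained sketch of the same idea, assuming only packaging and importlib.metadata (the helper name is hypothetical, not the library function):

import operator
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse

_OPS = {
    ">": operator.gt, ">=": operator.ge,
    "<": operator.lt, "<=": operator.le,
    "==": operator.eq, "!=": operator.ne,
}


def torch_xla_version_satisfies(operation: str, reference: str) -> bool:
    """Compare the installed torch_xla version against `reference` using `operation`."""
    try:
        installed = version("torch_xla")
    except PackageNotFoundError:
        # torch_xla is not installed, so no version constraint can be satisfied.
        return False
    return _OPS[operation](parse(installed), parse(reference))


# Example: torch_xla_version_satisfies(">", "2.2") is True on torch_xla 2.3 and later.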
