Commit fb29e37

Set up the option to use XLA flash attention or not
1 parent 5969ce4 commit fb29e37

3 files changed: +60 additions, -16 deletions

examples/research_projects/pytorch_xla/train_text_to_image_xla.py

Lines changed: 1 addition & 0 deletions
@@ -520,6 +520,7 @@ def main(args):
     from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
 
     unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    unet.enable_use_xla_flash_attention(partition_spec=("data", None, None, None))
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
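For context, this single added line routes every attention block in the UNet to the new XLA flash attention processor; under SPMD, the `("data", None, None, None)` spec maps the batch dimension of the attention tensors to the mesh's `data` axis. A minimal sketch of the same toggle outside the training script (the checkpoint name is only an example, and a TPU host with `torch_xla` >= 2.3 is assumed):

```python
from diffusers import UNet2DConditionModel

# Example checkpoint; any UNet2DConditionModel is handled the same way.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-base", subfolder="unet"
)

# Route all Attention modules to XLAFlashAttnProcessor2_0. The partition spec
# is only consulted when running under SPMD.
unet.enable_use_xla_flash_attention(partition_spec=("data", None, None, None))

# ... train or run inference on the XLA device ...

# Restore the default AttnProcessor2_0 / AttnProcessor.
unet.disable_use_xla_flash_attention()
```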

src/diffusers/models/attention_processor.py

Lines changed: 30 additions & 16 deletions
@@ -278,21 +278,33 @@ def __init__(
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
-        # If torch_xla is available with the correct version, we use pallas flash attention kernel to improve
-        # the performance.
         if processor is None:
-            if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
-                if (
-                    is_torch_xla_available
-                    and is_torch_xla_version('>', '2.2')
-                    and (not is_spmd() or is_torch_xla_version('>', '2.3'))
-                ):
-                    processor = XLAFlashAttnProcessor2_0()
-                else:
-                    processor = AttnProcessor2_0()
-            else:
-                processor = AttnProcessor()
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None) -> None:
+        r"""
+        Set whether to use xla flash attention from `torch_xla` or not.
 
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+        """
+        if (
+            use_xla_flash_attention
+            and is_torch_xla_available
+            and is_torch_xla_version('>', '2.2')
+            and (not is_spmd() or is_torch_xla_version('>', '2.3'))
+        ):
+            processor = XLAFlashAttnProcessor2_0(partition_spec)
+        else:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
         self.set_processor(processor)
 
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
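The same toggle is now exposed per module on `Attention` itself, which is what the model-level helper ultimately calls. A minimal sketch (the dimensions are arbitrary example values; note that on a stack without a recent `torch_xla` the setter quietly falls back to the default processors rather than raising):

```python
from diffusers.models.attention_processor import Attention

# A bare attention block with made-up example dimensions.
attn = Attention(query_dim=320, heads=8, dim_head=40)

# Opt this one module into the pallas kernel; without torch_xla > 2.2 the
# branch above silently installs AttnProcessor2_0 / AttnProcessor instead.
attn.set_use_xla_flash_attention(True, partition_spec=("data", None, None, None))
print(type(attn.processor).__name__)  # XLAFlashAttnProcessor2_0 or a fallback

# Switch back to the default processor.
attn.set_use_xla_flash_attention(False)
```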
@@ -2772,16 +2784,17 @@ def __call__(
 
 class XLAFlashAttnProcessor2_0:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using torch_xla).
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
     """
 
-    def __init__(self):
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         if is_torch_xla_version("<", "2.3"):
             raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
         if is_spmd() and is_torch_xla_version("<", "2.4"):
             raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
 
     def __call__(
         self,
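Constructing the processor directly mirrors what `set_use_xla_flash_attention` does internally, with one behavioral difference visible above: the constructor raises `ImportError` on an unsupported stack instead of falling back. A short sketch, assuming `torch_xla` >= 2.3 is installed; the `set_attn_processor` call in the comment is the existing UNet API, mentioned only as an illustration:

```python
from diffusers.models.attention_processor import XLAFlashAttnProcessor2_0

# Build the processor by hand; the version checks in __init__ raise
# ImportError if torch_xla is too old (or lacks SPMD support when needed).
processor = XLAFlashAttnProcessor2_0(partition_spec=("data", None, None, None))

# It could then be installed wholesale, e.g. unet.set_attn_processor(processor),
# instead of going through the new enable/disable helpers.
```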
@@ -2854,7 +2867,7 @@ def __call__(
                 # Apply attention mask to key
                 key = key + attention_mask
             query /= math.sqrt(query.shape[3])
-            partition_spec = ("data", None, None, None) if is_spmd() else None
+            partition_spec = self.partition_spec if is_spmd() else None
             hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
         else:
             hidden_states = F.scaled_dot_product_attention(
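The stored `partition_spec` is only consulted when `is_spmd()` is true. As a point of reference, a single-axis mesh that the `("data", None, None, None)` spec could refer to might be set up as below; this is an assumption-laden sketch (torch_xla >= 2.4 for the SPMD flash attention path, and the axis name must match the spec), not part of this commit:

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Enable SPMD and build a one-axis "data" mesh over all available devices.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))
xs.set_global_mesh(mesh)

# With SPMD active, XLAFlashAttnProcessor2_0 forwards the stored spec to
# flash_attention, sharding the batch dimension across the "data" axis.
```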
@@ -5201,6 +5214,7 @@ def __init__(self):
             FusedCogVideoXAttnProcessor2_0,
             XFormersAttnAddedKVProcessor,
             XFormersAttnProcessor,
+            XLAFlashAttnProcessor2_0,
             AttnProcessorNPU,
             AttnProcessor2_0,
             MochiVaeAttnProcessor2_0,

src/diffusers/models/modeling_utils.py

Lines changed: 29 additions & 0 deletions
@@ -208,6 +208,35 @@ def disable_npu_flash_attention(self) -> None:
         """
         self.set_use_npu_flash_attention(False)
 
+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any child that exposes the set_use_xla_flash_attention method
+        # gets the message.
+        def fn_recursive_set_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_xla_flash_attention"):
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+
+            for child in module.children():
+                fn_recursive_set_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_flash_attention(module)
+
+    def enable_use_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+        r"""
+        Enable the flash attention pallas kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(True, partition_spec)
+
+    def disable_use_xla_flash_attention(self):
+        r"""
+        Disable the flash attention pallas kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
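The `modeling_utils.py` helpers are a plain recursive broadcast: walk the children and hand the flag to any module that exposes the hook, so only attention blocks actually react. A toy illustration of that dispatch pattern in plain PyTorch (no diffusers or torch_xla required; `Leaf` and `Tree` are made up for this sketch):

```python
import torch

class Leaf(torch.nn.Module):
    # Stands in for diffusers' Attention: it exposes the hook, so it reacts.
    def set_use_xla_flash_attention(self, use_xla_flash_attention, partition_spec=None):
        print(f"{self.__class__.__name__}: use_xla_flash_attention={use_xla_flash_attention}")

class Tree(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = Leaf()
        self.block = torch.nn.Sequential(torch.nn.Linear(4, 4), Leaf())

def fn_recursive_set_flash_attention(module, use_xla_flash_attention, partition_spec=None):
    if hasattr(module, "set_use_xla_flash_attention"):
        module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
    for child in module.children():
        fn_recursive_set_flash_attention(child, use_xla_flash_attention, partition_spec)

for child in Tree().children():
    fn_recursive_set_flash_attention(child, True)
# Prints twice: once for `attn`, once for the Leaf nested inside `block`;
# the Linear layer is traversed but ignored.
```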
