Merged
src/diffusers/models/attention_dispatch.py (27 changes: 22 additions & 5 deletions)
@@ -473,12 +473,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):

# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility


# TODO: library.custom_op and register_fake probably need version guards?
# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")

if torch.__version__ >= "2.4.0":
    custom_op = torch.library.custom_op
    register_fake = torch.library.register_fake
else:
    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        return wrap if fn is None else fn

    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        return wrap if fn is None else fn

    custom_op = custom_op_no_op
    register_fake = register_fake_no_op


@custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
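
As an aside for readers skimming the diff: the guard above keeps the decorator syntax identical across PyTorch versions; `torch.library.custom_op` and `torch.library.register_fake` only exist from 2.4 on, so older versions get pass-through decorators and the wrapped functions remain plain callables. A minimal, self-contained sketch of the same pattern with a toy op (the `mylib::add_one` name and the example itself are illustrative, not part of this PR):

```python
import torch

# Same guard as in the diff: real registrations on torch >= 2.4,
# pass-through decorators otherwise.
if torch.__version__ >= "2.4.0":
    custom_op = torch.library.custom_op
    register_fake = torch.library.register_fake
else:
    def custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        return wrap if fn is None else fn

    def register_fake(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        return wrap if fn is None else fn


# Toy op: on torch >= 2.4 this becomes a real custom op (torch.ops.mylib.add_one);
# on older torch the decorator is a no-op and add_one stays a plain function.
@custom_op("mylib::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


@register_fake("mylib::add_one")
def _(x: torch.Tensor) -> torch.Tensor:
    # Fake/meta kernel: only shape and dtype matter, which is what lets
    # torch.compile(fullgraph=True) trace through the op without running it.
    return torch.empty_like(x)


print(add_one(torch.ones(3)))  # tensor([2., 2., 2.]) on either code path
```

The `wrap if fn is None else fn` return mirrors the real API's signature: it supports both decorator-factory usage (as above) and direct calls where the function is passed positionally.
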
@@ -487,7 +504,7 @@ def _wrapped_flash_attn_3_original(
    return out, lse


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
@register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size, seq_len, num_heads, head_dim = query.shape
    lse_shape = (batch_size, seq_len, num_heads)
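
The `register_fake` hunk is what backs the comment at the top of this section ("Registrations are required for fullgraph tracing compatibility"): under `torch.compile(fullgraph=True)`, Dynamo needs a shape-and-dtype-only fake kernel to trace through the otherwise opaque FA3 kernel call. A hedged sketch of how the registered op would be exercised; `run_fa3` is a hypothetical caller rather than anything in diffusers, and running it for real needs a CUDA device, the flash-attn 3 beta installed, and the module above imported so the registration has happened:

```python
import torch

# Hypothetical caller (not part of the PR): routes attention through the op the
# diff registers as "flash_attn_3::_flash_attn_forward". The fake kernel is what
# lets Dynamo trace this call with fullgraph=True instead of graph-breaking on
# an unknown CUDA kernel.
def run_fa3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out, _lse = torch.ops.flash_attn_3._flash_attn_forward(query, key, value)
    return out


compiled_run_fa3 = torch.compile(run_fa3, fullgraph=True)

# Layout follows the fake impl shown above: (batch_size, seq_len, num_heads, head_dim).
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = compiled_run_fa3(q, k, v)  # same layout as the query tensor
```
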