src/diffusers/models/attention_dispatch.py: 26 changes (21 additions & 5 deletions)
@@ -109,6 +109,24 @@
import xformers.ops as xops
else:
xops = None

# Version guard for PyTorch compatibility: torch.library.custom_op and register_fake were added in PyTorch 2.4
if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:
    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        return wrap if fn is None else fn

    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        return wrap if fn is None else fn

    _custom_op = custom_op_no_op
    _register_fake = register_fake_no_op


logger = get_logger(__name__) # pylint: disable=invalid-name
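For readers unfamiliar with the fallback trick above, here is a minimal sketch (not part of the diff) of what the pass-through decorators do on PyTorch < 2.4: the decorated function is returned unchanged, so the module still imports and eager calls still work, while the torch.library registration is simply skipped. The `mylib::scaled_add` op is a hypothetical example, not something from diffusers.

```python
import torch


# Same pass-through fallback as in the diff above (used when torch < 2.4).
def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
    def wrap(func):
        return func

    return wrap if fn is None else fn


@custom_op_no_op("mylib::scaled_add", mutates_args=(), device_types="cuda")
def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + 2.0 * y


# The decorator returns the function as-is: no torch.library op is created,
# but the symbol exists and eager calls keep working.
print(scaled_add(torch.ones(2), torch.ones(2)))  # tensor([3., 3.])
```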
@@ -473,12 +491,10 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):

# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility


# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release of FA3 does not have it. There is a PR adding
# this, but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")

@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +503,7 @@ def _wrapped_flash_attn_3_original(
return out, lse


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
@_register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)
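As a companion illustration (not part of the diff), the sketch below shows the same custom-op plus fake-kernel pattern on a toy CPU op, assuming PyTorch >= 2.4. The fake implementation only reports output shapes and dtypes, which is what lets `torch.compile(fullgraph=True)` trace through the op without executing the real kernel, mirroring what the fake `_flash_attn_forward` above does for the FA3 outputs. The `mylib::row_sum` name is hypothetical.

```python
import torch


# Real kernel: registered for CPU tensors and run eagerly outside of tracing.
@torch.library.custom_op("mylib::row_sum", mutates_args=(), device_types="cpu")
def row_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)


# Fake kernel: shape/dtype propagation only, no data is touched during tracing.
@torch.library.register_fake("mylib::row_sum")
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty(x.shape[:-1], dtype=x.dtype, device=x.device)


@torch.compile(fullgraph=True)
def f(x):
    return torch.ops.mylib.row_sum(x) * 2.0


print(f(torch.randn(4, 8)).shape)  # torch.Size([4])
```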