diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 7cc30e47ab14..6a05aac215c6 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -110,6 +110,27 @@
 else:
     xops = None
 
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if torch.__version__ >= "2.4.0":
+    _custom_op = torch.library.custom_op
+    _register_fake = torch.library.register_fake
+else:
+
+    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    _custom_op = custom_op_no_op
+    _register_fake = register_fake_no_op
+
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
-
-
-# TODO: library.custom_op and register_fake probably need version guards?
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+
+
+@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
     return out, lse
 
 
-@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+@_register_fake("flash_attn_3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
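
One caveat in the guard as written: `torch.__version__ >= "2.4.0"` is a lexicographic string comparison, so a future `"2.10.0"` would sort below `"2.4.0"` and silently select the no-op fallbacks. Below is a minimal sketch of a parsed-version guard that keeps the fallback decorators from the patch unchanged; the `packaging` import and the `_torch_2_4_or_newer` name are illustrative, not part of the diff, and diffusers' existing `is_torch_version(">=", "2.4.0")` helper should behave equivalently.

```python
# Sketch only, not part of the patch: the same guard with a parsed version
# comparison instead of a raw string comparison. Assumes `packaging` is
# importable in this environment.
import torch
from packaging import version

_torch_2_4_or_newer = version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.4.0")

if _torch_2_4_or_newer:
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:
    # No-op fallbacks, unchanged from the patch: the decorator returns the
    # function as-is, so the module still imports on PyTorch < 2.4, but no
    # custom op or fake (meta) kernel is actually registered.
    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    _custom_op = custom_op_no_op
    _register_fake = register_fake_no_op
```

Either way, the call sites (`@_custom_op(...)` and `@_register_fake(...)`) are unaffected, which is the point of binding the two names up front.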