Commit 1c10a60

Merge branch 'main' into modular-diffusers/refactor-guider-outputs
2 parents: e570e59 + 9a7ae77

File tree

1 file changed: +25 -5 lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 25 additions & 5 deletions
@@ -110,6 +110,27 @@
 else:
     xops = None
 
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if torch.__version__ >= "2.4.0":
+    _custom_op = torch.library.custom_op
+    _register_fake = torch.library.register_fake
+else:
+
+    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    _custom_op = custom_op_no_op
+    _register_fake = register_fake_no_op
+
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
-
-
-# TODO: library.custom_op and register_fake probably need version guards?
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+
+
+@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
     return out, lse
 
 
-@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+@_register_fake("flash_attn_3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
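
For reference, the guarded-decorator pattern this commit introduces can be exercised on its own. The sketch below is an assumption-laden illustration, not code from diffusers: the "mylib::add_one" op and the add_one function are hypothetical and exist only to show that, on PyTorch >= 2.4, _custom_op and _register_fake register a real custom op and its fake (meta) implementation, while on older versions they fall back to pass-through decorators so the module still imports.

# Standalone sketch of the version-guarded decorator pattern.
# The "mylib::add_one" op is a hypothetical example, not part of diffusers.
import torch

if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:
    # Older PyTorch: the decorators become no-ops, so decorated functions
    # remain plain Python callables and import still succeeds.
    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    _custom_op = custom_op_no_op
    _register_fake = register_fake_no_op


# On PyTorch >= 2.4 this registers "mylib::add_one" as a real custom op;
# on older versions the decorator simply returns the function unchanged.
@_custom_op("mylib::add_one", mutates_args=(), device_types="cpu")
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


# The fake registration lets fake-tensor tracing (e.g. torch.compile) infer
# output shapes without running the real kernel; it only takes effect on >= 2.4.
@_register_fake("mylib::add_one")
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


print(add_one(torch.zeros(2)))  # tensor([1., 1.]) on either code path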
