Commit 5ef7da4

Move version check to top of file and use private naming as requested
1 parent f515990 commit 5ef7da4


src/diffusers/models/attention_dispatch.py

Lines changed: 18 additions & 19 deletions
@@ -109,6 +109,24 @@
     import xformers.ops as xops
 else:
     xops = None
+
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if torch.__version__ >= "2.4.0":
+    _custom_op = torch.library.custom_op
+    _register_fake = torch.library.register_fake
+else:
+    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+        return wrap if fn is None else fn
+
+    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+        return wrap if fn is None else fn
+
+    _custom_op = custom_op_no_op
+    _register_fake = register_fake_no_op
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -473,28 +491,9 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
-# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
 
-if torch.__version__ >= "2.4.0":
-    _custom_op = torch.library.custom_op
-    _register_fake = torch.library.register_fake
-else:
-    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
-        def wrap(func):
-            return func
-        return wrap if fn is None else fn
-
-    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
-        def wrap(func):
-            return func
-        return wrap if fn is None else fn
-
-    _custom_op = custom_op_no_op
-    _register_fake = register_fake_no_op
-
-
 @_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor