Commit ec47936

Apply style fixes
1 parent 0a3a228 commit ec47936

File tree

1 file changed: +8 -4 lines changed


src/diffusers/models/attention_dispatch.py

Lines changed: 8 additions & 4 deletions
@@ -109,24 +109,27 @@
     import xformers.ops as xops
 else:
     xops = None
-
+
 # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
 if torch.__version__ >= "2.4.0":
     _custom_op = torch.library.custom_op
     _register_fake = torch.library.register_fake
 else:
+
     def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
         def wrap(func):
             return func
+
         return wrap if fn is None else fn
-
+
     def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
         def wrap(func):
             return func
+
         return wrap if fn is None else fn
-
+
     _custom_op = custom_op_no_op
-    _register_fake = register_fake_no_op
+    _register_fake = register_fake_no_op
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
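
For context, a minimal self-contained sketch of the pattern this hunk reformats (not code from the commit): on PyTorch >= 2.4 the aliases resolve to the real torch.library.custom_op and torch.library.register_fake decorators, while on older versions the no-op fallbacks simply return the decorated function, so the decorated callable behaves like a plain Python function. The op name "mylib::add_one" and its implementation are illustrative assumptions, not part of diffusers.

import torch

# Same guard as in the hunk above.
if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:

    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    _custom_op = custom_op_no_op
    _register_fake = register_fake_no_op


# Hypothetical op, used only to illustrate the aliases.
@_custom_op("mylib::add_one", mutates_args=(), device_types="cpu")
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


@_register_fake("mylib::add_one")
def _(x: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation: only shape and dtype matter while tracing.
    return torch.empty_like(x)


print(add_one(torch.zeros(3)))  # tensor([1., 1., 1.]) on either branch of the guard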
@@ -494,6 +497,7 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
 
+
 @_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
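
The decorator in this hunk registers the real CUDA implementation under the name flash_attn_3::_flash_attn_forward; for torch.compile tracing, a fake (meta) implementation is typically registered alongside it through the _register_fake alias from the first hunk. A hedged sketch of what such a registration could look like is below; the (output, logsumexp) return convention and the tensor shapes are assumptions, not taken from this commit.

# Sketch only: assumes query/key/value are (batch, seq_len, num_heads, head_dim)
# and that the op returns (attention output, logsumexp), mirroring common FA APIs.
@_register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    batch_size, seq_len, num_heads, _ = query.shape
    out = torch.empty_like(query)
    lse = query.new_empty((batch_size, seq_len, num_heads), dtype=torch.float32)
    return out, lse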
