```diff
     import xformers.ops as xops
 else:
     xops = None
-    
+
 # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
 if torch.__version__ >= "2.4.0":
     _custom_op = torch.library.custom_op
     _register_fake = torch.library.register_fake
 else:
+
     def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
         def wrap(func):
             return func
+
         return wrap if fn is None else fn
-    
+
     def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
         def wrap(func):
             return func
+
         return wrap if fn is None else fn
-    
+
     _custom_op = custom_op_no_op
-    _register_fake = register_fake_no_op    
+    _register_fake = register_fake_no_op
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -494,6 +497,7 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
 
+
 @_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
```
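For context on how the guarded decorators above are meant to be used, here is a minimal standalone sketch (not part of this commit) of the `_custom_op` / `_register_fake` pattern. The op name `mylib::scaled_add`, the function `scaled_add`, and the fake kernel `_scaled_add_fake` are hypothetical illustrations; only `torch.library.custom_op` and `torch.library.register_fake` (added in PyTorch 2.4) are real APIs. On older PyTorch versions the no-op fallbacks simply return the decorated function unchanged, so everything degrades to plain eager calls.

```python
import torch

# Same version guard as the commit: fall back to no-op decorators on PyTorch < 2.4.
if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:
    def _custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        return wrap if fn is None else fn

    def _register_fake(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        return wrap if fn is None else fn


# Hypothetical op, for illustration only. On PyTorch >= 2.4 this registers a
# real custom op; on older versions the decorator returns the plain function.
@_custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + 2.0 * y


# Fake (meta) implementation: returns only shape/dtype information, used
# when the op is traced rather than executed.
@_register_fake("mylib::scaled_add")
def _scaled_add_fake(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


out = scaled_add(torch.randn(4), torch.randn(4))
```

The fake implementation gives tracing machinery such as `torch.compile` the output shapes and dtypes without running the real kernel, which appears to be why the diff above wraps the FA3 forward call in a custom op instead of calling it directly.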