|  | 
| 109 | 109 |     import xformers.ops as xops | 
| 110 | 110 | else: | 
| 111 | 111 |     xops = None | 
|  | 112 | +     | 
|  | 113 | +# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4 | 
|  | 114 | +if torch.__version__ >= "2.4.0": | 
|  | 115 | +    _custom_op = torch.library.custom_op | 
|  | 116 | +    _register_fake = torch.library.register_fake | 
|  | 117 | +else: | 
|  | 118 | +    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None): | 
|  | 119 | +        def wrap(func): | 
|  | 120 | +            return func | 
|  | 121 | +        return wrap if fn is None else fn | 
|  | 122 | +     | 
|  | 123 | +    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1): | 
|  | 124 | +        def wrap(func): | 
|  | 125 | +            return func | 
|  | 126 | +        return wrap if fn is None else fn | 
|  | 127 | +     | 
|  | 128 | +    _custom_op = custom_op_no_op | 
|  | 129 | +    _register_fake = register_fake_no_op     | 
| 112 | 130 | 
 | 
| 113 | 131 | 
 | 
| 114 | 132 | logger = get_logger(__name__)  # pylint: disable=invalid-name | 
| @@ -473,28 +491,9 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): | 
| 473 | 491 | 
 | 
| 474 | 492 | # ===== torch op registrations ===== | 
| 475 | 493 | # Registrations are required for fullgraph tracing compatibility | 
| 476 |  | -# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4 | 
| 477 | 494 | # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding | 
| 478 | 495 | # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 | 
| 479 | 496 | 
 | 
| 480 |  | -if torch.__version__ >= "2.4.0": | 
| 481 |  | -    _custom_op = torch.library.custom_op | 
| 482 |  | -    _register_fake = torch.library.register_fake | 
| 483 |  | -else: | 
| 484 |  | -    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None): | 
| 485 |  | -        def wrap(func): | 
| 486 |  | -            return func | 
| 487 |  | -        return wrap if fn is None else fn | 
| 488 |  | -     | 
| 489 |  | -    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1): | 
| 490 |  | -        def wrap(func): | 
| 491 |  | -            return func | 
| 492 |  | -        return wrap if fn is None else fn | 
| 493 |  | -     | 
| 494 |  | -    _custom_op = custom_op_no_op | 
| 495 |  | -    _register_fake = register_fake_no_op | 
| 496 |  | - | 
| 497 |  | - | 
| 498 | 497 | @_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") | 
| 499 | 498 | def _wrapped_flash_attn_3_original( | 
| 500 | 499 |     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor | 
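
The moved block selects between the real `torch.library` decorators and no-op fallbacks so the same decorator syntax works on PyTorch < 2.4. Below is a minimal standalone sketch of that shim, for illustration only; it assumes `packaging` is available so the version check is numeric rather than a lexicographic string comparison (which would misorder e.g. `"2.10.0"` against `"2.4.0"`).

```python
import torch
from packaging import version

# custom_op / register_fake were added in PyTorch 2.4.
_HAS_CUSTOM_OP = version.parse(torch.__version__).release >= (2, 4)

if _HAS_CUSTOM_OP:
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:
    # No-op fallbacks: return the decorated function unchanged so the
    # decorator syntax still works on older PyTorch (the op simply
    # isn't registered and eager execution is used).
    def _custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        return wrap if fn is None else fn

    def _register_fake(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        return wrap if fn is None else fn
```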
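For the registrations themselves, the sketch below shows the general pattern the comment refers to (assuming PyTorch >= 2.4): a custom op registered for CUDA plus a `register_fake` shape/dtype-only implementation, which is what lets `torch.compile(fullgraph=True)` trace through the call without running the kernel. The op name `mylib::scaled_matmul` and the function bodies are made up for illustration; they are not the actual flash_attn_3 wrapper.

```python
import torch

@torch.library.custom_op("mylib::scaled_matmul", mutates_args=(), device_types="cuda")
def scaled_matmul(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    # Eager CUDA implementation of the op.
    return (a @ b) * scale

@torch.library.register_fake("mylib::scaled_matmul")
def _(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake (meta) implementation: produces only shapes/dtypes, no real compute,
    # so the compiler can trace the op without launching a CUDA kernel.
    return a.new_empty(a.shape[0], b.shape[1])

@torch.compile(fullgraph=True)
def attentionish(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return scaled_matmul(a, b, 0.5)

# Usage (needs a CUDA device, since the op is only registered for "cuda"):
# a = torch.randn(4, 8, device="cuda")
# b = torch.randn(8, 16, device="cuda")
# out = attentionish(a, b)  # shape (4, 16)
```

Without the fake registration, `torch.compile(fullgraph=True)` would have to graph-break (or fail) at the op call, which is the fullgraph-tracing compatibility issue the comment describes.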