 else:
     xops = None
 
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if torch.__version__ >= "2.4.0":
+    _custom_op = torch.library.custom_op
+    _register_fake = torch.library.register_fake
+else:
+
+    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    _custom_op = custom_op_no_op
+    _register_fake = register_fake_no_op
+
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
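The no-op fallbacks in the guard above mirror both call styles that `torch.library.custom_op` and `torch.library.register_fake` accept (decorator factory, or the function passed directly as `fn`), so the registration sites below never need to branch on the PyTorch version; on < 2.4 the decorators simply hand the function back unchanged and the wrappers run as plain Python functions with no dispatcher registration. A minimal sketch of that pass-through behavior (illustrative only; `identity` is a hypothetical function, and the fallbacks are only defined on the < 2.4 branch):

    def identity(x):
        return x

    # Decorator-factory style: `fn` is None, so a wrapper is returned and applied.
    assert custom_op_no_op("mylib::identity", mutates_args=())(identity) is identity

    # Direct style: the function is handed straight back.
    assert register_fake_no_op("mylib::identity", identity) is identity
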
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
-
-
-# TODO: library.custom_op and register_fake probably need version guards?
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+
+
+@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
     return out, lse
 
 
-@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+@_register_fake("flash_attn_3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
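To sanity-check the pair of registrations above in an environment where they are active, `torch.library.opcheck` (available since PyTorch 2.4) runs the registered fake kernel against the real CUDA kernel and validates the op's schema and fake-tensor propagation. A minimal sketch, assuming a CUDA machine with the FA3 beta installed; the shapes and dtype are illustrative:

    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Compares eager CUDA execution of flash_attn_3::_flash_attn_forward with the
    # fake kernel registered above; raises if their output metadata disagrees.
    torch.library.opcheck(_wrapped_flash_attn_3_original, (q, k, v))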