@@ -240,7 +240,7 @@ def pre_forward_split_hook(_self, args, kwargs):
 
 def local_flash_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
                      dist_attn, **kwargs):
-    if module not in text_model.modules():
+    if module.__class__ not in [m.__class__ for m in text_model.modules()]:
         return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query_states, key_states,
                                                                    value_states, attention_mask, *args,
                                                                    **kwargs)
@@ -261,7 +261,7 @@ def _attention(query, key, value, *args, **kwargs):
 
 def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
                     dist_attn, **kwargs):
-    if module not in text_model.modules():
+    if module.__class__ not in [m.__class__ for m in text_model.modules()]:
         return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query_states, key_states, value_states,
                                                       attention_mask, *args, **kwargs)
     if dist_attn.local_attn is None:
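
For reference, the change swaps an identity-based membership test for a class-based one: an attention module now takes the distributed path whenever its class appears anywhere in text_model, even if that exact instance is not registered in text_model. A minimal standalone sketch of the difference between the two checks, assuming hypothetical TextAttention and VisionAttention stand-in classes (not part of this repository):

import torch.nn as nn

class TextAttention(nn.Module):
    """Toy stand-in for an attention layer inside the text tower."""
    def forward(self, x):
        return x

class VisionAttention(nn.Module):
    """Toy stand-in for an attention layer outside the text tower."""
    def forward(self, x):
        return x

text_model = nn.Sequential(TextAttention(), TextAttention())

# Old check: identity membership. A module of the same class that was
# constructed (or wrapped) elsewhere is NOT found among text_model's modules.
other_text_attn = TextAttention()
print(other_text_attn in list(text_model.modules()))    # False

# New check: class membership. Any instance whose class occurs in
# text_model matches, while unrelated classes still fall through to
# the original attention implementation.
text_classes = [m.__class__ for m in text_model.modules()]
print(other_text_attn.__class__ in text_classes)        # True
print(VisionAttention().__class__ in text_classes)      # False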