
Commit 44d6cc0

tastelikefeet authored and Jintao-Huang committed
fix ulysses with vl (#5391)
1 parent eb0b03a commit 44d6cc0

File tree

1 file changed: +2 −2 lines


swift/trainers/sequence_parallel/ulysses.py

Lines changed: 2 additions & 2 deletions
@@ -240,7 +240,7 @@ def pre_forward_split_hook(_self, args, kwargs):
 
 def local_flash_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
                      dist_attn, **kwargs):
-    if module not in text_model.modules():
+    if module.__class__ not in [m.__class__ for m in text_model.modules()]:
         return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query_states, key_states,
                                                                    value_states, attention_mask, *args,
                                                                    **kwargs)
@@ -261,7 +261,7 @@ def _attention(query, key, value, *args, **kwargs):
 
 def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
                     dist_attn, **kwargs):
-    if module not in text_model.modules():
+    if module.__class__ not in [m.__class__ for m in text_model.modules()]:
         return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query_states, key_states, value_states,
                                                       attention_mask, *args, **kwargs)
     if dist_attn.local_attn is None:
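
Note on the fix (an inferred reading, not stated in the commit message): the old guard tested object identity — whether this exact module instance appears in text_model.modules() — which can miss the module in VL models where the attention call reaches an instance not contained in that iterator; the new guard compares classes instead, so any instance of one of the text model's module classes takes the sequence-parallel path. A minimal sketch of the difference, using toy modules (TextAttention and VisionAttention are invented names, not from ulysses.py):

# Hedged sketch (not from the commit): toy modules illustrating the
# identity-based vs. class-based membership check against text_model.modules().
import torch.nn as nn

class TextAttention(nn.Module):      # hypothetical text-side attention layer
    pass

class VisionAttention(nn.Module):    # hypothetical vision-side attention layer
    pass

text_model = nn.Sequential(TextAttention())

# A *different instance* of the same class, e.g. one reached through a wrapper
# or a separately constructed branch of a VL model.
other_text_attn = TextAttention()
vision_attn = VisionAttention()

# Old guard: identity membership; misses the second instance.
print(other_text_attn in text_model.modules())    # False

# New guard: class membership; matches any instance of a text-model class.
text_classes = [m.__class__ for m in text_model.modules()]
print(other_text_attn.__class__ in text_classes)  # True
print(vision_attn.__class__ in text_classes)      # False: vision attention
                                                  # falls through to the
                                                  # original attention function

The class check is deliberately coarser: it also matches instances that happen to share a class with the text model, which suits this guard since it routes by attention type rather than by specific instance.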
