Skip to content

Commit 85ca1f9

Browse files
tastelikefeet
authored and committed
fix ulysses (#5501)
Co-authored-by: tastelikefeet <[email protected]>
1 parent c7be303 commit 85ca1f9

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,18 @@ def flash_attention_mask(batch_size,
179179

180180
masking_utils.flash_attention_mask = flash_attention_mask
181181
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['flash_attention_2'] = flash_attention_mask
182+
183+
def create_causal_mask(config, input_embeds, attention_mask, cache_position, *args, **kwargs):
    """Wrapper installed over ``masking_utils.create_causal_mask`` for Ulysses sequence parallel.

    Each rank only holds ``1/sp_world_size`` of the sequence, so the mask must be
    built for the *full* (gathered) sequence length.  We therefore replace
    ``input_embeds`` with a dummy tensor of the full length (only its shape,
    dtype and device matter to the original mask builder) and rebuild
    ``cache_position`` to span that full length, then delegate to the saved
    original implementation.
    """
    # NOTE(review): assumes input_embeds is (batch, local_seq_len, hidden) — confirm upstream.
    batch_size, local_len, hidden_size = input_embeds.shape
    full_len = local_len * self.sp_world_size
    # Dummy embeddings at the gathered sequence length; values are irrelevant,
    # only the shape/dtype/device are consumed by the mask constructor.
    input_embeds = torch.ones((batch_size, full_len, hidden_size),
                              dtype=input_embeds.dtype,
                              device=input_embeds.device)
    # Positions must cover the full sequence, not just this rank's shard.
    cache_position = torch.arange(0, full_len, device=input_embeds.device)
    return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask,
                                                   cache_position, *args, **kwargs)

# Save the original once, then install the sequence-parallel-aware wrapper.
masking_utils.origin_create_causal_mask = masking_utils.create_causal_mask
masking_utils.create_causal_mask = create_causal_mask
182194
except ImportError:
183195
pass
184196

0 commit comments

Comments (0)