
Commit ec85d2c

Avoid CUDA stream sync (#40060)
Signed-off-by: cyy <[email protected]>
1 parent c7afaa5 commit ec85d2c

File tree: 1 file changed (+3, −3 lines)


src/transformers/modeling_flash_attention_utils.py

Lines changed: 3 additions & 3 deletions

@@ -354,12 +354,12 @@ def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool =
         max_length_q = int(q_len.max())
         max_length_k = int(last_position_ids.max()) + 1
     else:
-        position_ids = position_ids.flatten()
-        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
+        position_ids = position_ids.view(-1)
+        indices_q = (position_ids == 0).nonzero().view(-1)

         cu_seq_lens_q = torch.cat(
             (
-                indices_q[position_ids == 0],
+                indices_q,
                 torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
             )
         )
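For context, the patched else-branch flattens position_ids with view(-1) and derives the sequence-start indices up front via (position_ids == 0).nonzero(), instead of building an arange and boolean-masking it later inside torch.cat; per the commit title, the intent is to avoid a CUDA stream synchronization on this path. Below is a minimal, self-contained sketch of the resulting logic, not the full transformers implementation; the helper name and example inputs are illustrative only.

# Minimal sketch of the patched logic for deriving cumulative sequence lengths
# from packed position_ids. The helper name is illustrative, not from the repo.
import torch

def cu_seq_lens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # view(-1) flattens without copying (assumes a contiguous tensor).
    position_ids = position_ids.view(-1)
    # In a packed batch, each sequence restarts position_ids at 0, so the
    # positions of zeros mark the sequence starts.
    indices_q = (position_ids == 0).nonzero().view(-1)
    # Append the total token count so consecutive entries bracket each sequence.
    return torch.cat(
        (
            indices_q,
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )

# Example: two packed sequences of lengths 3 and 2.
pos = torch.tensor([0, 1, 2, 0, 1])
print(cu_seq_lens_from_position_ids(pos))  # tensor([0, 3, 5])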
