
Commit 9886a41

cursoragent and sami committed
Align AFMoE flash attn with sliding window
Co-authored-by: sami <[email protected]>
1 parent 94ff2aa

File tree

1 file changed: +8 -9 lines changed


src/prime_rl/trainer/models/afmoe/modeling_afmoe.py

Lines changed: 8 additions & 9 deletions
```diff
@@ -443,16 +443,15 @@ def forward(
         cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)

         if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
-            flat_position_ids = position_ids.view(-1)
-            seqlens = torch.cat(
-                [
-                    flat_position_ids[0:1],
-                    flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1,
-                    flat_position_ids[-1:] + 1,
-                ]
+            batch_size, seq_len = inputs_embeds.shape[:2]
+            cu_seqlens = torch.arange(
+                0,
+                (batch_size + 1) * seq_len,
+                step=seq_len,
+                dtype=torch.int32,
+                device=inputs_embeds.device,
             )
-            max_seqlen = seqlens.max().item()
-            cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
+            max_seqlen = seq_len
             torch._dynamo.mark_dynamic(cu_seqlens, 0)
         else:
             max_seqlen = None
```
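For context: the removed code derived variable-length sequence boundaries from resets in `position_ids` (packed documents), while the new code treats every batch row as one fixed-length sequence. A minimal sketch of the new `cu_seqlens` layout follows; the `batch_size` and `seq_len` values are hypothetical, and the commented `flash_attn_varlen_func` call (including the left-only `window_size`) is illustrative rather than taken from this file:

```python
import torch

# Sketch of the cu_seqlens layout after this commit: each of the batch_size
# rows is one contiguous sequence of seq_len tokens, so the cumulative
# sequence lengths form a fixed-stride ramp [0, L, 2L, ..., B*L].
batch_size, seq_len = 4, 8  # hypothetical values
cu_seqlens = torch.arange(
    0,
    (batch_size + 1) * seq_len,
    step=seq_len,
    dtype=torch.int32,
)
print(cu_seqlens.tolist())  # [0, 8, 16, 24, 32]
max_seqlen = seq_len  # every sequence has the same length

# These two values describe the packed (batch_size * seq_len, ...) layout that
# flash-attn's varlen kernels consume, roughly (illustrative call, assuming a
# causal attention with a left-only sliding window):
# out = flash_attn_varlen_func(
#     q, k, v,
#     cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#     max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
#     causal=True,
#     window_size=(sliding_window - 1, 0),
# )
```

A side effect of the fixed-stride construction: `max_seqlen` is now a plain Python int, so the diff can drop the old `seqlens.max().item()` call, which forced a device-to-host synchronization on every forward pass.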
