File tree Expand file tree Collapse file tree 1 file changed +8
-9
lines changed
src/prime_rl/trainer/models/afmoe Expand file tree Collapse file tree 1 file changed +8
-9
lines changed Original file line number Diff line number Diff line change @@ -443,16 +443,15 @@ def forward(
443443 cache_position = torch .arange (inputs_embeds .shape [1 ], device = inputs_embeds .device )
444444
445445 if self .config ._attn_implementation in ("flash_attention_2" , "flash_attention_3" ):
446- flat_position_ids = position_ids . view ( - 1 )
447- seqlens = torch .cat (
448- [
449- flat_position_ids [ 0 : 1 ] ,
450- flat_position_ids [: - 1 ][( flat_position_ids == 0 )[ 1 :]] + 1 ,
451- flat_position_ids [ - 1 :] + 1 ,
452- ]
446+ batch_size , seq_len = inputs_embeds . shape [: 2 ]
447+ cu_seqlens = torch .arange (
448+ 0 ,
449+ ( batch_size + 1 ) * seq_len ,
450+ step = seq_len ,
451+ dtype = torch . int32 ,
452+ device = inputs_embeds . device ,
453453 )
454- max_seqlen = seqlens .max ().item ()
455- cu_seqlens = seqlens .cumsum (dim = 0 , dtype = torch .int32 )
454+ max_seqlen = seq_len
456455 torch ._dynamo .mark_dynamic (cu_seqlens , 0 )
457456 else :
458457 max_seqlen = None
You can’t perform that action at this time.
0 commit comments