@@ -208,11 +208,23 @@ def forward(self, args):
208
208
alibi = position_ids
209
209
position_ids = attn_mask_startend_row_indices
210
210
attn_mask_startend_row_indices = None
211
- elif not self .config .alibi and position_ids is None and attn_mask_startend_row_indices is not None :
212
- # hidden_states, attention_mask, position_ids
213
- position_ids = attn_mask_startend_row_indices
214
- attn_mask_startend_row_indices = None
215
- alibi = None
211
+ elif not self .config .alibi :
212
+ if get_env_device () in ["gpu" ]:
213
+ if attention_mask is not None and attention_mask .dtype == paddle .int32 :
214
+ attention_mask , attn_mask_startend_row_indices , position_ids = (
215
+ None ,
216
+ attention_mask ,
217
+ attn_mask_startend_row_indices ,
218
+ )
219
+ elif attention_mask is not None and attention_mask .dtype == paddle .int64 :
220
+ attention_mask , attn_mask_startend_row_indices , position_ids = None , None , attention_mask
221
+ elif (
222
+ attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices .dtype == paddle .int64
223
+ ):
224
+ attn_mask_startend_row_indices , position_ids = None , attn_mask_startend_row_indices
225
+ elif position_ids is None and attn_mask_startend_row_indices is not None :
226
+ position_ids = attn_mask_startend_row_indices
227
+ attn_mask_startend_row_indices = None
216
228
217
229
has_gradient = not hidden_states .stop_gradient
218
230
if self .enable_recompute and self .config .recompute_granularity == "full" and has_gradient :
0 commit comments