Skip to content

Commit 1182cd5

Browse files
authored
support flashmask (#8670)
* support flashmask * support flashmask
1 parent dae64cc commit 1182cd5

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

paddlenlp/transformers/llama/modeling_pp.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,23 @@ def forward(self, args):
208208
alibi = position_ids
209209
position_ids = attn_mask_startend_row_indices
210210
attn_mask_startend_row_indices = None
211-
elif not self.config.alibi and position_ids is None and attn_mask_startend_row_indices is not None:
212-
# hidden_states, attention_mask, position_ids
213-
position_ids = attn_mask_startend_row_indices
214-
attn_mask_startend_row_indices = None
215-
alibi = None
211+
elif not self.config.alibi:
212+
if get_env_device() in ["gpu"]:
213+
if attention_mask is not None and attention_mask.dtype == paddle.int32:
214+
attention_mask, attn_mask_startend_row_indices, position_ids = (
215+
None,
216+
attention_mask,
217+
attn_mask_startend_row_indices,
218+
)
219+
elif attention_mask is not None and attention_mask.dtype == paddle.int64:
220+
attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
221+
elif (
222+
attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64
223+
):
224+
attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
225+
elif position_ids is None and attn_mask_startend_row_indices is not None:
226+
position_ids = attn_mask_startend_row_indices
227+
attn_mask_startend_row_indices = None
216228

217229
has_gradient = not hidden_states.stop_gradient
218230
if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:

0 commit comments

Comments
 (0)