@@ -299,7 +299,7 @@ def scaled_dot_product_attention(
 
 def masked_fill(x, mask, value):
     y = paddle.full(x.shape, value, x.dtype)
-    return paddle.where(mask, y, x)
+    return paddle.where(mask.to("bool"), y, x)
 
 
 def is_casual_mask(attention_mask):
@@ -519,7 +519,7 @@ def forward(self, hidden_states):
         # this will be used to easily index which expert is going to be sollicitated.
         # shape: [num_experts, top_k, batch_size * seq_len]
         expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0])
-
+        expert_mask = expert_mask.to("bool")
         # Loop over all available experts in the model and perform the computation on each expert.
         for expert_id in range(self.num_experts):
             expert_layer = self.experts[expert_id]
@@ -1098,7 +1098,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                 past_key_values_length=past_key_values_length,
             )
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+        expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
 
     @paddle.jit.not_to_static
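
A minimal standalone sketch (not part of the patch) of the pattern the hunks above share: paddle.where expects a boolean condition tensor, so a mask held as an integer or float tensor (for example the output of F.one_hot) is cast with .to("bool") before being passed in. The tensors and the fill value below are illustrative only.

import paddle

def masked_fill(x, mask, value):
    # Mirrors the patched helper: build a tensor filled with `value`
    # and select it wherever the bool-cast mask is True.
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask.to("bool"), y, x)

x = paddle.arange(6, dtype="float32").reshape([2, 3])
mask = paddle.to_tensor([[1, 0, 1], [0, 1, 0]], dtype="int64")  # non-bool mask
print(masked_fill(x, mask, -1e9))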