@@ -2962,12 +2962,6 @@ def __call__(
        # perturbed path (identity attention)
        batch_size, sequence_length, _ = hidden_states_ptb.shape

-       if attention_mask is not None:
-           attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-           # scaled_dot_product_attention expects attention_mask shape to be
-           # (batch, heads, source_length, target_length)
-           attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
        if attn.group_norm is not None:
            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

@@ -3070,12 +3064,6 @@ def __call__(
        # perturbed path (identity attention)
        batch_size, sequence_length, _ = hidden_states_ptb.shape

-       if attention_mask is not None:
-           attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-           # scaled_dot_product_attention expects attention_mask shape to be
-           # (batch, heads, source_length, target_length)
-           attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
        if attn.group_norm is not None:
            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

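The removed block is dead code on this branch: in the perturbed path the attention map is replaced by the identity, so scaled_dot_product_attention is never called and the prepared attention_mask is never read. A minimal sketch of that path, with names taken from the diff (the standalone function signature is assumed here for illustration; in the actual processors this logic lives inside __call__):

    def perturbed_path(attn, hidden_states_ptb):
        # Hypothetical helper mirroring the identity-attention branch shown above.
        # With an identity attention map, softmax(QK^T)V collapses to the value
        # projection of the input, so no mask preparation is needed.
        if attn.group_norm is not None:
            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

        # "identity attention": take the value projection directly
        hidden_states_ptb = attn.to_v(hidden_states_ptb)

        # output projection + dropout, as in the regular path
        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
        return hidden_states_ptb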