@@ -37,38 +37,9 @@ def llama_pos_shift_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    if self.config.pretraining_tp > 1:
-        key_value_slicing = (
-            self.num_key_value_heads * self.head_dim
-        ) // self.config.pretraining_tp
-        query_slices = self.q_proj.weight.split(
-            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-        )
-        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-        query_states = [
-            F.linear(hidden_states, query_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        query_states = torch.cat(query_states, dim=-1)
-
-        key_states = [
-            F.linear(hidden_states, key_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        key_states = torch.cat(key_states, dim=-1)
-
-        value_states = [
-            F.linear(hidden_states, value_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        value_states = torch.cat(value_states, dim=-1)
-
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(
         bsz, q_len, self.num_heads, self.head_dim
@@ -103,9 +74,9 @@ def llama_pos_shift_attention_forward(
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
-        self.head_dim
+    softmax_scale = 1.0 / math.sqrt(self.head_dim)
+    attn_weights = (
+        torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale
     )
 
     if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -114,6 +85,23 @@ def llama_pos_shift_attention_forward(
             f" {attn_weights.size()}"
         )
 
+    # Previously the causal path received an input attention mask; now causal mode
+    # does not expect a mask, so we need to generate the causal mask ourselves.
+    current_is_causal = False
+    if self.is_causal and attention_mask is None and q_len > 1:
+        current_is_causal = True
+    if current_is_causal and attention_mask is None:
+        bool_attention_mask = torch.ones(
+            [query_states.shape[-2], key_states.shape[-2]],
+            device=query_states.device,
+            dtype=torch.bool,
+        ).tril()
+        additive_attention_mask = torch.zeros_like(
+            bool_attention_mask, dtype=attn_weights.dtype
+        ).masked_fill(bool_attention_mask.logical_not(), -10000)
+        attn_weights = attn_weights + additive_attention_mask
+
+    # Legacy support to take in a mask for non-causal mode.
     if attention_mask is not None:
         if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
             raise ValueError(
@@ -132,30 +120,10 @@ def llama_pos_shift_attention_forward(
             f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
             f" {attn_output.size()}"
         )
-
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    if self.config.pretraining_tp > 1:
-        attn_output = attn_output.split(
-            self.hidden_size // self.config.pretraining_tp, dim=2
-        )
-        o_proj_slices = self.o_proj.weight.split(
-            self.hidden_size // self.config.pretraining_tp, dim=1
-        )
-        attn_output = sum(
-            [
-                F.linear(attn_output[i], o_proj_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
-        )
-    else:
-        attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
+    attn_output = self.o_proj(attn_output)
+    return attn_output, None, past_key_value
 
 
 def enable_llama_pos_shift_attention(model):
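
For reference, the additive causal mask built in the hunk above can be reproduced in isolation. The sketch below is illustrative only (the helper name is not part of the patch): positions above the diagonal get a large negative bias (-10000) added to the attention logits, while allowed positions stay at zero, matching the bool_attention_mask / additive_attention_mask logic introduced by the change.

import torch

def build_additive_causal_mask(q_len, kv_len, dtype, device=None):
    # True on and below the diagonal: these query/key pairs may attend.
    allowed = torch.ones(q_len, kv_len, dtype=torch.bool, device=device).tril()
    # Additive form used by the patch: 0 where allowed, -10000 where masked out.
    return torch.zeros(q_len, kv_len, dtype=dtype, device=device).masked_fill(
        allowed.logical_not(), -10000
    )

# Example: build_additive_causal_mask(4, 4, torch.float32) is zero on/below the
# diagonal and -10000 above it; adding it to the attention logits before softmax
# mirrors `attn_weights = attn_weights + additive_attention_mask` in the patched forward.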