@@ -181,7 +181,8 @@ def __init__(self,
 
         self.qkv_proj = QKVParallelLinear(
             self.d_model,
-            self.d_model // self.total_num_heads,
+            #self.d_model // self.total_num_heads,
+            self.key_value_proj_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
@@ -199,7 +200,8 @@ def __init__(self,
             padding_size=self.relative_attention_num_buckets,
             quant_config=quant_config)
         self.o = RowParallelLinear(
-            self.d_model,
+            #self.d_model,
+            self.total_num_heads * self.key_value_proj_dim,
             self.d_model,
             bias=False,
             quant_config=quant_config,
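The two hunks above size the QKV and output projections from the configured `key_value_proj_dim` instead of `d_model // total_num_heads`. In T5 the per-head width `d_kv` is an independent hyperparameter and need not equal `d_model // num_heads`. The arithmetic below is a minimal sketch of that mismatch, assuming the published t5-11b hyperparameters (d_model=1024, 128 heads, d_kv=128); it is not part of the diff.

```python
# Minimal sketch (assumed t5-11b config values): why the head dimension must
# come from d_kv / key_value_proj_dim rather than from d_model // num_heads.
d_model = 1024
total_num_heads = 128
key_value_proj_dim = 128  # d_kv in the T5 config

# The old expression would give a head size of 8 ...
assert d_model // total_num_heads == 8
# ... but each head actually projects to 128 dimensions, so the fused QKV
# output width and the o-projection input width must be derived from d_kv:
inner_dim = total_num_heads * key_value_proj_dim
assert inner_dim == 16384  # != d_model (1024)
```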
@@ -298,10 +300,12 @@ def forward(
     ) -> torch.Tensor:
         bs, seq_len, _ = hidden_states.shape
         num_seqs = bs
-        n, c = self.n_heads, self.d_model // self.total_num_heads
+        #n, c = self.n_heads, self.d_model // self.total_num_heads
+        n, c = self.n_heads, self.key_value_proj_dim
         qkv, _ = self.qkv_proj(hidden_states)
         # Projection of 'own' hidden state (self-attention). No GQA here.
-        q, k, v = qkv.split(self.inner_dim, dim=-1)
+        #q, k, v = qkv.split(self.inner_dim, dim=-1)
+        q, k, v = qkv.split(self.qkv_proj.output_sizes, dim=-1)
         q = q.reshape(bs, seq_len, n, c)
         k = k.reshape(bs, seq_len, n, c)
         v = v.reshape(bs, seq_len, n, c)
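The forward-pass hunk stops splitting the fused QKV tensor by a single width (`self.inner_dim`) and instead splits it by `self.qkv_proj.output_sizes`, i.e. by the individual widths of the q, k and v chunks. The snippet below is a small illustration of that `torch.split` behaviour under assumed sizes; the attribute name `output_sizes` comes from the diff itself, and the concrete values it holds in `QKVParallelLinear` are not shown here.

```python
# Minimal sketch (assumed sizes, not the model's real dimensions): splitting a
# fused [q | k | v] tensor by a list of per-chunk widths instead of one width.
import torch

num_heads, num_kv_heads, head_dim = 8, 8, 64
q_size = num_heads * head_dim       # 512
kv_size = num_kv_heads * head_dim   # 512
output_sizes = [q_size, kv_size, kv_size]

qkv = torch.randn(2, 5, sum(output_sizes))  # (batch, seq_len, q+k+v)
# Splitting by a list yields exactly three chunks of the requested widths;
# splitting by one int only works when all three chunks happen to be equal.
q, k, v = qkv.split(output_sizes, dim=-1)
assert q.shape == (2, 5, q_size) and k.shape == (2, 5, kv_size)
```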