@@ -126,8 +126,6 @@ def __init__(
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.is_causal = is_causal
-        # Number of queries per k, v
-        self.q_per_kv = self.num_heads // self.num_kv_heads
 
         # Set layers
         self.kv_cache = kv_cache
@@ -145,7 +143,7 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
-            q_per_kv=self.q_per_kv,
+            q_per_kv=self.num_heads // self.num_kv_heads,
             attn_dropout=self.attn_dropout if self.training else 0.0,
             is_causal=self.is_causal,
             attention_fn=self._attention_call,
@@ -239,7 +237,10 @@ def forward(
 
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
-        q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)
+
+        # number of queries per key/value
+        q_per_kv = self.num_heads // self.num_kv_heads
+        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
 
         # Apply positional embeddings
         if self.pos_embeddings is not None:
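For context, here is a minimal sketch of the reshape the last hunk performs, using hypothetical sizes (`b`, `s_x`, `num_heads`, `num_kv_heads`, `head_dim` below are illustrative, not the module's defaults). In grouped-query attention each group of `q_per_kv` query heads shares one key/value head, so `num_kv_heads * q_per_kv == num_heads` and the view is shape-preserving:

```python
import torch

# hypothetical sizes for illustration only
b, s_x = 2, 16                                 # batch size, sequence length
num_heads, num_kv_heads, head_dim = 8, 2, 64   # query heads, kv heads, per-head dim

# queries per key/value head, computed inline as in this PR
q_per_kv = num_heads // num_kv_heads           # 4

q = torch.randn(b, s_x, num_heads * head_dim)  # stand-in for q_proj(x)
# expose the per-kv grouping: [b, s_x, num_kv_heads * q_per_kv, head_dim]
q = q.view(b, s_x, num_kv_heads * q_per_kv, head_dim)
print(q.shape)  # torch.Size([2, 16, 8, 64])
```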