@@ -210,6 +210,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
         self.use_qk_norm = config.use_qk_norm
+        self.use_conv2d = False
 
         assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
@@ -255,9 +256,25 @@ def forward(
         in_cache_state = kwargs.get("in_cache_state")
         out_cache_state = kwargs.get("out_cache_state")
 
+        bsz, seq_len, dim = x.shape
+        if self.use_conv2d:
+            x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
+
         new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
         new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
         new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
+
+        if self.use_conv2d:
+
+            def from_conv2ds(ts):
+                return [
+                    t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
+                ]
+
+            new_qs = from_conv2ds(new_qs)
+            new_ks = from_conv2ds(new_ks)
+            new_vs = from_conv2ds(new_vs)
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
         all_ks = []
@@ -282,7 +299,14 @@ def forward(
             heads.append(attn @ all_vs[kv_idx])
 
         y = torch.cat(heads, dim=-1)
-        y = self.wo(y)
+        if self.use_conv2d:
+            y = (
+                self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
+                .transpose(1, 3)
+                .reshape(bsz, seq_len, -1)
+            )
+        else:
+            y = self.wo(y)
         return y, {"out_cache_state": out_cache_state}
 
     def load_weights_from_attention_mha(self, other: AttentionMHA):
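Note on the forward-pass change above: the reshape/transpose pair moves `dim` onto the channel axis (NCHW layout, with a dummy height of 1), so a 1x1 nn.Conv2d reduces over it exactly like the matmul in nn.Linear, and `from_conv2ds` undoes the layout afterwards. A minimal standalone sketch of that equivalence, with made-up shape values (not taken from ModelArgs):

# Sketch: a 1x1 nn.Conv2d over an NCHW view matches nn.Linear when the
# conv weight is the linear weight with two trailing singleton dims.
import torch
import torch.nn as nn

bsz, seq_len, dim, head_dim = 2, 16, 64, 8

linear = nn.Linear(dim, head_dim, bias=False)
conv = nn.Conv2d(dim, head_dim, 1, bias=False)
conv.weight.data.copy_(linear.weight[:, :, None, None])  # (out, in) -> (out, in, 1, 1)

x = torch.randn(bsz, seq_len, dim)

y_linear = linear(x)                                      # (bsz, seq_len, head_dim)

x_nchw = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)  # (bsz, dim, 1, seq_len)
y_conv = conv(x_nchw)                                     # (bsz, head_dim, 1, seq_len)
y_conv = y_conv.reshape(bsz, head_dim, seq_len).transpose(1, 2)

torch.testing.assert_close(y_conv, y_linear, rtol=1e-4, atol=1e-5)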
@@ -300,3 +324,44 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
             )
 
         self.wo.weight.data.copy_(other.wo.weight)
+
+    def linear_to_conv2d(self):
+        def transfer_weight(linear, conv2d):
+            conv2d.weight.data.copy_(linear.weight[:, :, None, None])
+            return conv2d
+
+        self.wqs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wqs
+            ]
+        )
+        self.wks = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wks
+            ]
+        )
+        self.wvs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wvs
+            ]
+        )
+        self.wo = transfer_weight(
+            self.wo,
+            nn.Conv2d(
+                self.n_heads * self.head_dim, self.dim, 1, bias=self.attention_qkv_bias
+            ),
+        )
+
+        self.use_conv2d = True
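The same weight-copy trick covers the output projection: after linear_to_conv2d(), `wo` is a 1x1 nn.Conv2d over the concatenated heads, and the reshape/transpose pair in forward() restores the (bsz, seq_len, dim) layout. A hedged sketch of that path in isolation, again with made-up shapes rather than a real config:

# Sketch: converting an output projection nn.Linear to a 1x1 nn.Conv2d and
# applying it through the same layout round trip as the conv2d branch above.
import torch
import torch.nn as nn

bsz, seq_len, n_heads, head_dim, dim = 2, 16, 4, 8, 64

def transfer_weight(linear, conv2d):
    # (out_features, in_features) -> (out_channels, in_channels, 1, 1)
    conv2d.weight.data.copy_(linear.weight[:, :, None, None])
    return conv2d

wo_linear = nn.Linear(n_heads * head_dim, dim, bias=False)
wo_conv = transfer_weight(wo_linear, nn.Conv2d(n_heads * head_dim, dim, 1, bias=False))

y = torch.randn(bsz, seq_len, n_heads * head_dim)  # concatenated attention heads

out_linear = wo_linear(y)
out_conv = (
    wo_conv(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
    .transpose(1, 3)
    .reshape(bsz, seq_len, -1)
)
torch.testing.assert_close(out_conv, out_linear, rtol=1e-4, atol=1e-5)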