@@ -104,7 +104,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
 
         self.scale = float(self.head_dim) ** 0.5
 
-        if hasattr(config, "enable_r3") and config.enable_r3:
+        if getattr(config, "enable_r3", False):
             self.register_buffer(
                 "r3_weight",
                 torch.tensor(
@@ -223,18 +223,20 @@ def forward_sha( # noqa: C901
             if self.use_qk_norm and self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
             q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin)
-            if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                q[i] = torch.matmul(q[i], self.r3_weight)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
+            if getattr(self.config, "enable_r3", False):
+                q[i] = torch.matmul(q[i], self.r3_weight)
+
         for i in range(len(k)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
-            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin).transpose(1, 2)
-            if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                k[i] = torch.matmul(k[i], self.r3_weight)
+            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
+            if getattr(self.config, "enable_r3", False):
+                k[i] = torch.matmul(k[i], self.r3_weight)
+            k[i] = k[i].transpose(1, 2)
 
         output_y = []
         kh, vh = [], []
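
Note on the flag check: getattr(config, "enable_r3", False) gives the same truth value as the old hasattr(config, "enable_r3") and config.enable_r3 pattern, but as a single lookup with an explicit default. A minimal standalone sketch (the Cfg class here is a hypothetical stand-in for ModelArgs):

class Cfg:
    pass

cfg = Cfg()
# Attribute missing: both forms are falsy.
assert not (hasattr(cfg, "enable_r3") and cfg.enable_r3)
assert not getattr(cfg, "enable_r3", False)

cfg.enable_r3 = True
# Attribute present and truthy: both forms are truthy.
assert hasattr(cfg, "enable_r3") and cfg.enable_r3
assert getattr(cfg, "enable_r3", False)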
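And a sketch of the per-head K path the second hunk establishes. Shapes are hypothetical (batch 1, seq 8, head_dim 4), LayerNorm stands in for the model's k_norm_fn, and the identity matrix stands in for the real r3_weight; the point is only the ordering: RoPE output, then the optional qk-norm, then the R3 rotation, and only then the transpose for the attention matmul.

import torch

head_dim = 4
r3_weight = torch.eye(head_dim)           # stand-in for the real R3 rotation matrix
k_norm_fn = torch.nn.LayerNorm(head_dim)  # stand-in for the model's k-norm

k = torch.randn(1, 8, head_dim)  # (batch, seq, head_dim); assume RoPE already applied
k = k_norm_fn(k)                 # qk-norm after RoPE (the qk_norm_before_rope=False case)
k = torch.matmul(k, r3_weight)   # R3 rotation now runs after the norm, not before it
k = k.transpose(1, 2)            # transpose deferred until after the rotation
print(k.shape)                   # torch.Size([1, 4, 8])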