@@ -352,7 +352,7 @@ def forward(
         x_r, x_i = x[..., ::2], x[..., 1::2]
         x_out_r = x_r * freqs_cos - x_i * freqs_sin
         x_out_i = x_r * freqs_sin + x_i * freqs_cos
-        x_out = torch.cat([x_out_r, x_out_i], dim=-1)
+        x_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)
         return x_out

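Not part of the diff, just a quick standalone sanity check of the change above: with the interleaved real/imaginary layout this rope uses, `torch.cat` de-interleaves the rotated channels (all real parts first, then all imaginary parts), whereas `torch.stack(...).flatten(2)` keeps them interleaved like the input. The (batch, seq, head_dim) shape below is illustrative, not taken from the model code.

```python
import torch

# Illustrative shapes: (batch, seq, head_dim); an identity rotation
# (cos = 1, sin = 0) should leave the interleaved layout untouched.
x = torch.arange(8.0).reshape(1, 1, 8)
x_r, x_i = x[..., ::2], x[..., 1::2]
freqs_cos, freqs_sin = torch.ones(1, 1, 4), torch.zeros(1, 1, 4)

x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos

cat_out = torch.cat([x_out_r, x_out_i], dim=-1)                  # 0,2,4,6,1,3,5,7
stack_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)   # 0,1,2,...,7

assert torch.equal(stack_out, x)       # new version preserves channel order
assert not torch.equal(cat_out, x)     # old version permutes the channels
```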
@@ -378,6 +378,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
         self.use_qk_norm = config.use_qk_norm
+        self.qk_norm_before_rope = config.qk_norm_before_rope
         self.use_conv2d = False

         self.wqs = nn.ModuleList(
@@ -449,12 +450,17 @@ def from_conv2ds(ts):
         new_ks = from_conv2ds(new_ks)
         new_vs = from_conv2ds(new_vs)

-        if self.use_qk_norm:
+        if self.use_qk_norm and self.qk_norm_before_rope:
             new_qs = [self.q_norm(q) for q in new_qs]
             new_ks = [self.k_norm(k) for k in new_ks]

         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
+
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            new_qs = [self.q_norm(q) for q in new_qs]
+            new_ks = [self.k_norm(k) for k in new_ks]
+
         all_ks = []
         all_vs = []
         for i in range(self.n_kv_heads):
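For readers skimming the hunk, a condensed sketch of the control flow the new `qk_norm_before_rope` flag introduces (helper names here are hypothetical stand-ins, not from this file):

```python
# Illustrative only: q_norm/k_norm/apply_rope stand in for the module's own
# RMSNorm and rope calls; the flag picks which side of RoPE the norm runs on.
def normalize_and_rotate(q, k, q_norm, k_norm, apply_rope, qk_norm_before_rope):
    if qk_norm_before_rope:            # norm first, then rotary embedding
        q, k = q_norm(q), k_norm(k)
    q, k = apply_rope(q), apply_rope(k)
    if not qk_norm_before_rope:        # rotary embedding first, then norm
        q, k = q_norm(q), k_norm(k)
    return q, k
```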
@@ -505,6 +511,7 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):

         if other.use_qk_norm:
             self.use_qk_norm = True
+            self.qk_norm_before_rope = other.qk_norm_before_rope
             self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
             self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
             self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)