@@ -114,15 +114,30 @@ def update(
114114 return all_data , (out_k_cache , out_v_cache )
115115
116116
117- def _apply_rotary_embedding (
118- x : torch .Tensor , freqs_cos : torch .Tensor , freqs_sin : torch .Tensor
119- ) -> torch .Tensor :
120- x_r , x_i = x [..., ::2 ], x [..., 1 ::2 ]
121- x_out_r = x_r * freqs_cos - x_i * freqs_sin
122- x_out_i = x_r * freqs_sin + x_i * freqs_cos
117+ class _Rope (nn .Module ):
118+ def __init__ (self , use_hf_rope ):
119+ super ().__init__ ()
120+ self .use_hf_rope = use_hf_rope
121+
122+ def forward (
123+ self , x : torch .Tensor , freqs_cos : torch .Tensor , freqs_sin : torch .Tensor
124+ ) -> torch .Tensor :
125+ if self .use_hf_rope :
126+ if len (freqs_cos .shape ) == 2 :
127+ freqs_cos = freqs_cos .unsqueeze (0 )
128+ if len (freqs_sin .shape ) == 2 :
129+ freqs_sin = freqs_sin .unsqueeze (0 )
130+ x1 = x [..., : x .shape [- 1 ] // 2 ]
131+ x2 = x [..., x .shape [- 1 ] // 2 :]
132+ x_rotated = torch .cat ((- x2 , x1 ), dim = - 1 )
133+ return x * freqs_cos + x_rotated * freqs_sin
134+ else :
135+ x_r , x_i = x [..., ::2 ], x [..., 1 ::2 ]
136+ x_out_r = x_r * freqs_cos - x_i * freqs_sin
137+ x_out_i = x_r * freqs_sin + x_i * freqs_cos
123138
124- x_out = torch .cat ([x_out_r , x_out_i ], dim = - 1 )
125- return x_out
139+ x_out = torch .cat ([x_out_r , x_out_i ], dim = - 1 )
140+ return x_out
126141
127142
128143@register_attention ("static" )
@@ -172,6 +187,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
172187 [StaticVCache (layer_id , i ) for i in range (self .n_kv_heads )]
173188 )
174189 self .wo = nn .Linear (self .n_heads * self .head_dim , self .dim , bias = False )
190+ self .rope = _Rope (rope .params .use_hf_rope )
175191
176192 def forward (
177193 self ,
@@ -191,8 +207,8 @@ def forward(
191207 new_qs = [self .wqs [i ](x ) for i in range (self .n_heads )]
192208 new_ks = [self .wks [i ](x ) for i in range (self .n_kv_heads )]
193209 new_vs = [self .wvs [i ](x ) for i in range (self .n_kv_heads )]
194- new_qs = [_apply_rotary_embedding (q , freqs_cos , freqs_sin ) for q in new_qs ]
195- new_ks = [_apply_rotary_embedding (k , freqs_cos , freqs_sin ) for k in new_ks ]
210+ new_qs = [self . rope (q , freqs_cos , freqs_sin ) for q in new_qs ]
211+ new_ks = [self . rope (k , freqs_cos , freqs_sin ) for k in new_ks ]
196212
197213 all_ks = []
198214 all_vs = []
@@ -211,7 +227,7 @@ def forward(
211227 kv_idx = i // self .n_heads_per_kv_group
212228 attn = new_qs [i ] @ all_ks [kv_idx ].transpose (- 2 , - 1 )
213229 attn = attn * self .inv_scale
214- attn = attn + mask # pyre-ignore
230+ attn = attn + mask
215231 attn = F .softmax (attn , dim = - 1 )
216232 heads .append (attn @ all_vs [kv_idx ])
217233
0 commit comments