@@ -226,20 +226,21 @@ def forward(
         v = self.v_proj(hidden_states)
         gk = self.gk_proj(hidden_states)

-        if self.feature_map_fn is not None:
-            q, k = map(self.feature_map_fn, (q, k))
         q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
         if self.num_kv_groups > 1:
             k, gk = (repeat(x, '... (h d) -> ... (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
             v = repeat(v, '... (h d) -> ... (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
         else:
             k, gk = (rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim) for x in (k, gk))
             v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
-        gk = F.logsigmoid(gk) / self.gate_logit_normalizer

+        gk = F.logsigmoid(gk) / self.gate_logit_normalizer
         if self.clamp_min is not None:
             gk = torch.clamp_min(gk, self.clamp_min)

+        if self.feature_map_fn is not None:
+            q, k = map(self.feature_map_fn, (q, k))
+
         recurrent_state = last_state['recurrent_state'] if last_state is not None else None
         if mode == 'fused_recurrent':
             o, recurrent_state = fused_recurrent_gla(
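In short, the commit moves the feature-map application so that it runs after the projections are split into heads and after the log-space gate is computed, instead of on the flat (h d) projections. Below is a minimal standalone sketch of the reordered flow, with hypothetical toy dimensions, a stand-in ReLU feature map, and made-up gate_logit_normalizer / clamp_min values (the real module takes all of these from its config):

    import torch
    import torch.nn.functional as F
    from einops import rearrange, repeat

    B, T, hidden_size = 2, 8, 64              # toy batch / sequence / model sizes
    num_heads, num_kv_heads = 4, 2            # hypothetical GQA layout
    head_k_dim = head_v_dim = 16
    num_kv_groups = num_heads // num_kv_heads

    q_proj = torch.nn.Linear(hidden_size, num_heads * head_k_dim)
    k_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_k_dim)
    v_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_v_dim)
    gk_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_k_dim)

    hidden_states = torch.randn(B, T, hidden_size)
    q, k, v, gk = (p(hidden_states) for p in (q_proj, k_proj, v_proj, gk_proj))

    # 1. split the flat projections into heads, broadcasting k/v/gk across KV groups
    q = rearrange(q, '... (h d) -> ... h d', d=head_k_dim)
    k, gk = (repeat(x, '... (h d) -> ... (h g) d', g=num_kv_groups, d=head_k_dim) for x in (k, gk))
    v = repeat(v, '... (h d) -> ... (h g) d', g=num_kv_groups, d=head_v_dim)

    # 2. turn the gate projection into a normalized, clamped log-space forget gate
    gate_logit_normalizer, clamp_min = 16, -5.0   # hypothetical values
    gk = F.logsigmoid(gk) / gate_logit_normalizer
    gk = torch.clamp_min(gk, clamp_min)

    # 3. only now apply the feature map, so it sees per-head (..., h, d) tensors
    feature_map_fn = F.relu                       # stand-in for the configured map
    q, k = map(feature_map_fn, (q, k))

    print(q.shape, k.shape, gk.shape)  # torch.Size([2, 8, 4, 16]) for each

One practical consequence of the reordering: the feature map now operates on head_k_dim-sized vectors rather than the concatenated num_heads * head_k_dim projection, which matters for any feature map that normalizes or mixes within its last dimension.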