Skip to content

Commit 2e73362

Browse files
authored
[GLA] change the order of reshape and apply feature map. Fixing #606
1 parent 26be14f commit 2e73362

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

fla/layers/gla.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,20 +226,21 @@ def forward(
226226
v = self.v_proj(hidden_states)
227227
gk = self.gk_proj(hidden_states)
228228

229-
if self.feature_map_fn is not None:
230-
q, k = map(self.feature_map_fn, (q, k))
231229
q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
232230
if self.num_kv_groups > 1:
233231
k, gk = (repeat(x, '... (h d) -> ... (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
234232
v = repeat(v, '... (h d) -> ... (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
235233
else:
236234
k, gk = (rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim) for x in (k, gk))
237235
v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
238-
gk = F.logsigmoid(gk) / self.gate_logit_normalizer
239236

237+
gk = F.logsigmoid(gk) / self.gate_logit_normalizer
240238
if self.clamp_min is not None:
241239
gk = torch.clamp_min(gk, self.clamp_min)
242240

241+
if self.feature_map_fn is not None:
242+
q, k = map(self.feature_map_fn, (q, k))
243+
243244
recurrent_state = last_state['recurrent_state'] if last_state is not None else None
244245
if mode == 'fused_recurrent':
245246
o, recurrent_state = fused_recurrent_gla(

0 commit comments

Comments
 (0)