from __future__ import annotations
from beartype.typing import NamedTuple, Tuple
+ from functools import partial

import torch
from torch import nn, Tensor
    typecheck
)

+ # alias
+
+ LinearNoBias = partial(nn.Linear, bias = False)
+
# helpers

def exists(val):
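
Aside (not part of the diff): LinearNoBias is ordinary functools.partial application, so LinearNoBias(d_in, d_out) constructs exactly nn.Linear(d_in, d_out, bias = False). A minimal standalone sketch, with hypothetical dimensions:

    # sketch only -- demonstrates the alias, dimensions are made up
    from functools import partial
    from torch import nn

    LinearNoBias = partial(nn.Linear, bias = False)

    proj = LinearNoBias(512, 1024)   # same as nn.Linear(512, 1024, bias = False)
    assert proj.bias is None         # partial fixed bias = False, so no bias param
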
@@ -178,7 +183,6 @@ def __init__(
        num_memory_kv: int = 0,
        enable_attn_softclamp = False,
        attn_softclamp_value = 50.,
-       init_gate_bias = -2.,
        softmax_full_precision = False
    ):
        super().__init__()
@@ -209,8 +213,8 @@ def __init__(
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.to_q = nn.Linear(dim, dim_inner, bias = query_bias)
-       self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
-       self.to_out = nn.Linear(dim_inner, dim, bias = False)
+       self.to_kv = LinearNoBias(dim, dim_inner * 2)
+       self.to_out = LinearNoBias(dim_inner, dim)

        self.memory_kv = None
@@ -224,11 +228,7 @@ def __init__(
        self.to_gates = None

        if gate_output:
-           gate_linear = nn.Linear(dim, dim_inner)
-           nn.init.zeros_(gate_linear.weight)
-           nn.init.constant_(gate_linear.bias, init_gate_bias)
-
-           self.to_gates = gate_linear
+           self.to_gates = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())

    @typecheck
    def forward(
@@ -266,7 +266,7 @@ def forward(

        if exists(self.to_gates):
            gates = self.to_gates(seq)
-           out = out * gates.sigmoid()
+           out = out * gates

        # combine heads

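Aside (not part of the diff): the commit folds the sigmoid into the gate module itself and drops the old zero-weight / init_gate_bias = -2. initialization, which is why the forward pass no longer calls gates.sigmoid(). A minimal sketch of the new gating path, with hypothetical shapes (batch 2, 16 tokens, dim = dim_inner = 64):

    # sketch only -- mirrors the new to_gates construction above
    import torch
    from torch import nn
    from functools import partial

    LinearNoBias = partial(nn.Linear, bias = False)

    dim = dim_inner = 64
    to_gates = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())

    seq = torch.randn(2, 16, dim)             # input sequence
    attn_out = torch.randn(2, 16, dim_inner)  # stand-in for the attention output
    out = attn_out * to_gates(seq)            # sigmoid now lives inside to_gates

One behavioral consequence worth noting: with default random weights and no bias, the gates start centered around 0.5, whereas the removed init (zero weight, bias -2.) started them near sigmoid(-2) ≈ 0.12, i.e. the layer no longer begins close to gated-off.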