Commit ffe6a89

Remove zero init and constant biasing for to_gate (#304)
* Update attention.py
* Update attention.py
1 parent 41aa91c commit ffe6a89

File tree

1 file changed: +9 -9 lines changed


alphafold3_pytorch/attention.py

Lines changed: 9 additions & 9 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from beartype.typing import NamedTuple, Tuple
+from functools import partial
 
 import torch
 from torch import nn, Tensor
@@ -18,6 +19,10 @@
     typecheck
 )
 
+# alias
+
+LinearNoBias = partial(nn.Linear, bias = False)
+
 # helpers
 
 def exists(val):
@@ -178,7 +183,6 @@ def __init__(
         num_memory_kv: int = 0,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
-        init_gate_bias = -2.,
         softmax_full_precision = False
     ):
         super().__init__()
@@ -209,8 +213,8 @@ def __init__(
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
         self.to_q = nn.Linear(dim, dim_inner, bias = query_bias)
-        self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
-        self.to_out = nn.Linear(dim_inner, dim, bias = False)
+        self.to_kv = LinearNoBias(dim, dim_inner * 2)
+        self.to_out = LinearNoBias(dim_inner, dim)
 
         self.memory_kv = None
 
@@ -224,11 +228,7 @@ def __init__(
         self.to_gates = None
 
         if gate_output:
-            gate_linear = nn.Linear(dim, dim_inner)
-            nn.init.zeros_(gate_linear.weight)
-            nn.init.constant_(gate_linear.bias, init_gate_bias)
-
-            self.to_gates = gate_linear
+            self.to_gates = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())
 
     @typecheck
     def forward(
@@ -266,7 +266,7 @@ def forward(
 
         if exists(self.to_gates):
             gates = self.to_gates(seq)
-            out = out * gates.sigmoid()
+            out = out * gates
 
         # combine heads
 
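Below is a minimal, self-contained sketch (not code from the repository; dim, dim_inner, and the tensor shapes are illustrative) contrasting the gating scheme removed by this commit with the one it introduces. The old gate was zero-initialized with a constant bias of -2, so every gate starts near sigmoid(-2) ≈ 0.12 and the attention output is largely suppressed at initialization; the new gate is a bias-free linear followed by an explicit sigmoid, so with PyTorch's default weight init the gates average roughly 0.5.

import torch
from torch import nn
from functools import partial

LinearNoBias = partial(nn.Linear, bias = False)

dim, dim_inner = 16, 32          # illustrative sizes, not from the repo
seq = torch.randn(2, 8, dim)     # (batch, seq_len, dim)

# old scheme (removed by this commit): zero weight + constant bias of -2
old_gate = nn.Linear(dim, dim_inner)
nn.init.zeros_(old_gate.weight)
nn.init.constant_(old_gate.bias, -2.)
print(old_gate(seq).sigmoid().mean())   # ~0.12, gates start mostly closed

# new scheme (this commit): bias-free linear followed by sigmoid, default init
new_gate = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())
print(new_gate(seq).mean())             # ~0.5 on average, gates start open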
Comments (0)