
Commit 27a7818

they removed all the bias from the triangle multiplicative modules
1 parent 1e3c221 commit 27a7818

2 files changed: +12 −21 lines

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 20 deletions
@@ -433,18 +433,14 @@ def __init__(
         dim_hidden = default(dim_hidden, dim)
         self.norm = nn.LayerNorm(dim)
 
-        self.left_proj = Linear(dim, dim_hidden)
-        self.right_proj = Linear(dim, dim_hidden)
-
-        self.left_gate = Linear(dim, dim_hidden)
-        self.right_gate = Linear(dim, dim_hidden)
-        self.out_gate = Linear(dim, dim_hidden)
+        self.left_right_proj = nn.Sequential(
+            LinearNoBias(dim, dim_hidden * 4),
+            nn.GLU(dim = -1)
+        )
 
-        # initialize all gating to be identity
+        self.left_right_gate = LinearNoBias(dim, dim_hidden * 2)
 
-        for gate in (self.left_gate, self.right_gate, self.out_gate):
-            nn.init.constant_(gate.weight, 0.)
-            nn.init.constant_(gate.bias, 1.)
+        self.out_gate = LinearNoBias(dim, dim_hidden)
 
         if mix == 'outgoing':
             self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
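
A note on what the new left_right_proj computes: nn.GLU splits its input in half along the given dimension and returns first_half * sigmoid(second_half), so a single bias-free projection to dim_hidden * 4 followed by GLU fuses the four old Linear layers (the left/right projections and their sigmoid gates) into one matmul. A minimal sketch of that equivalence, assuming LinearNoBias is simply a bias-free nn.Linear (consistent with how this diff uses it):

    import torch
    from torch import nn
    from functools import partial

    LinearNoBias = partial(nn.Linear, bias = False)  # assumed definition

    dim = dim_hidden = 64
    x = torch.randn(2, 16, 16, dim)  # pairwise representation

    proj = nn.Sequential(
        LinearNoBias(dim, dim_hidden * 4),
        nn.GLU(dim = -1)  # halves the last dim: values * sigmoid(gates)
    )

    left, right = proj(x).chunk(2, dim = -1)  # (2, 16, 16, dim_hidden) each

    # the same computation written out by hand
    w = proj[0].weight
    values, gates = (x @ w.t()).chunk(2, dim = -1)
    assert torch.allclose(proj(x), values * gates.sigmoid(), atol = 1e-6)

One behavioral difference worth noting: the old code zero-initialized the gate weights and set their biases to 1, so every gate started nearly open (sigmoid(1) ≈ 0.73); with the biases removed, that warm-start initialization is gone as well.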
@@ -454,7 +450,7 @@ def __init__(
         self.to_out_norm = nn.LayerNorm(dim_hidden)
 
         self.to_out = Sequential(
-            Linear(dim_hidden, dim),
+            LinearNoBias(dim_hidden, dim),
             Dropout(dropout, dropout_type = dropout_type)
         )
 
@@ -470,24 +466,19 @@ def forward(
 
         x = self.norm(x)
 
-        left = self.left_proj(x)
-        right = self.right_proj(x)
+        left, right = self.left_right_proj(x).chunk(2, dim = -1)
 
         if exists(mask):
             left = left * mask
             right = right * mask
 
-        left_gate = self.left_gate(x).sigmoid()
-        right_gate = self.right_gate(x).sigmoid()
-        out_gate = self.out_gate(x).sigmoid()
-
-        left = left * left_gate
-        right = right * right_gate
-
         out = einsum(left, right, self.mix_einsum_eq)
 
         out = self.to_out_norm(out)
+
+        out_gate = self.out_gate(x).sigmoid()
         out = out * out_gate
+
         return self.to_out(out)
 
 # there are two types of attention in this paper, triangle and attention-pair-bias
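
Putting the hunks together, here is a minimal runnable sketch of the module after this commit. Everything outside the diff is an assumption made for self-containment: the class name is illustrative, LinearNoBias is taken to be a bias-free nn.Linear, torch.einsum stands in for the repo's tensors-first einsum wrapper, nn.Dropout stands in for the repo's Dropout(dropout, dropout_type = ...), only the 'outgoing' mix shown in the diff is included, and the left_right_gate defined in __init__ is omitted since the forward shown here never references it:

    import torch
    from torch import nn
    from functools import partial

    LinearNoBias = partial(nn.Linear, bias = False)  # assumed definition

    def exists(v):
        return v is not None

    def default(v, d):
        return v if exists(v) else d

    class TriangleMultiplicativeSketch(nn.Module):
        def __init__(self, dim, dim_hidden = None, mix = 'outgoing', dropout = 0.):
            super().__init__()
            dim_hidden = default(dim_hidden, dim)
            self.norm = nn.LayerNorm(dim)

            # fused bias-free projection + GLU replaces the four old Linear layers
            self.left_right_proj = nn.Sequential(
                LinearNoBias(dim, dim_hidden * 4),
                nn.GLU(dim = -1)
            )

            self.out_gate = LinearNoBias(dim, dim_hidden)

            assert mix == 'outgoing'  # only the mix shown in this diff
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'

            self.to_out_norm = nn.LayerNorm(dim_hidden)
            self.to_out = nn.Sequential(
                LinearNoBias(dim_hidden, dim),
                nn.Dropout(dropout)  # stand-in for the repo's custom Dropout
            )

        def forward(self, x, mask = None):
            x = self.norm(x)
            left, right = self.left_right_proj(x).chunk(2, dim = -1)

            if exists(mask):
                left = left * mask
                right = right * mask

            # torch.einsum takes the equation first; the repo's wrapper takes tensors first
            out = torch.einsum(self.mix_einsum_eq, left, right)
            out = self.to_out_norm(out)

            out_gate = self.out_gate(x).sigmoid()
            out = out * out_gate
            return self.to_out(out)

    block = TriangleMultiplicativeSketch(dim = 64)
    pair = torch.randn(1, 32, 32, 64)
    print(block(pair).shape)  # torch.Size([1, 32, 32, 64])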

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.37"
+version = "0.0.38"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
