
Commit f746f7b

complete register tokens

1 parent 936210c · commit f746f7b

3 files changed: +42 -6 lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 37 additions & 4 deletions

@@ -1323,6 +1323,7 @@ def __init__(
         dim_pairwise = 128,
         attn_window_size = None,
         attn_pair_bias_kwargs: dict = dict(),
+        num_register_tokens = 0,
         serial = False
     ):
         super().__init__()
@@ -1365,6 +1366,12 @@ def __init__(
 
         self.serial = serial
 
+        self.has_registers = num_register_tokens > 0
+        self.num_registers = num_register_tokens
+
+        if self.has_registers:
+            self.registers = nn.Parameter(torch.zeros(num_register_tokens, dim))
+
     @typecheck
     def forward(
         self,
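Note: register tokens here follow the pattern popularized by "Vision Transformers Need Registers" (Darcet et al., 2023) — a small set of learned embeddings is prepended to the sequence so attention has extra scratch slots, and those slots are discarded at the output. A minimal standalone sketch of the setup this hunk adds; the class and names below are illustrative, not the repo's API:

import torch
from torch import nn
from einops import repeat

class RegisterTokens(nn.Module):
    # illustrative stand-in: owns a few learned register embeddings
    def __init__(self, num_register_tokens: int, dim: int):
        super().__init__()
        self.registers = nn.Parameter(torch.zeros(num_register_tokens, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, dim) -> prepend the registers to every sequence in the batch
        r = repeat(self.registers, 'r d -> b r d', b = x.shape[0])
        return torch.cat((r, x), dim = 1)

tokens = torch.randn(2, 16, 384)
assert RegisterTokens(4, 384)(tokens).shape == (2, 20, 384)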
@@ -1376,6 +1383,21 @@ def forward(
     ):
         serial = self.serial
 
+        # register tokens
+
+        if self.has_registers:
+            num_registers = self.num_registers
+            registers = repeat(self.registers, 'r d -> b r d', b = noised_repr.shape[0])
+            noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')
+
+            single_repr = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
+            pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
+
+            if exists(mask):
+                mask = F.pad(mask, (num_registers, 0), value = True)
+
+        # main transformer
+
         for attn, transition in self.layers:
 
             attn_out = attn(
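Note: because the registers are prepended to noised_repr, everything indexed by sequence position has to grow by the same amount at the front — the single conditioning, both sequence axes of the pairwise conditioning, and the mask (padded with True so the register positions stay attendable). A self-contained sketch of the pack/pad calls above on dummy tensors; batch, length and feature sizes are made up:

import torch
import torch.nn.functional as F
from einops import pack, repeat

b, n, d, num_registers = 2, 16, 384, 4

registers     = repeat(torch.zeros(num_registers, d), 'r d -> b r d', b = b)
noised_repr   = torch.randn(b, n, d)
single_repr   = torch.randn(b, n, d)
pairwise_repr = torch.randn(b, n, n, 128)
mask          = torch.ones(b, n, dtype = torch.bool)

# prepend the registers along the sequence dimension, remembering the packed shapes
noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')

# F.pad reads its pad spec from the last dimension backwards, in (left, right) pairs:
# (0, 0) leaves the feature dim alone, (num_registers, 0) pads the front of a sequence dim
single_repr   = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
mask          = F.pad(mask, (num_registers, 0), value = True)

assert noised_repr.shape   == (b, n + num_registers, d)
assert single_repr.shape   == (b, n + num_registers, d)
assert pairwise_repr.shape == (b, n + num_registers, n + num_registers, 128)
assert mask.shape          == (b, n + num_registers)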
@@ -1398,6 +1420,11 @@ def forward(
 
             noised_repr = noised_repr + ff_out
 
+        # splice out registers
+
+        if self.has_registers:
+            _, noised_repr = unpack(noised_repr, registers_ps, 'b * d')
+
         return noised_repr
 
 class AtomToTokenPooler(Module):
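Note: on the way out, unpack uses the shape record returned by pack to split the register slots back off, so callers see the same sequence length they passed in. A tiny round-trip sketch with dummy shapes:

import torch
from einops import pack, unpack

registers   = torch.zeros(2, 4, 384)
noised_repr = torch.randn(2, 16, 384)

packed, registers_ps = pack((registers, noised_repr), 'b * d')
# ... the transformer layers would operate on `packed` here ...
_, out = unpack(packed, registers_ps, 'b * d')

assert out.shape == noised_repr.shape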
@@ -1487,7 +1514,10 @@ def __init__(
         token_transformer_heads = 16,
         atom_decoder_depth = 3,
         atom_decoder_heads = 4,
-        serial = False
+        serial = False,
+        atom_encoder_kwargs: dict = dict(),
+        atom_decoder_kwargs: dict = dict(),
+        token_transformer_kwargs: dict = dict()
     ):
         super().__init__()
 
@@ -1543,7 +1573,8 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_encoder_depth,
             heads = atom_encoder_heads,
-            serial = serial
+            serial = serial,
+            **atom_encoder_kwargs
         )
 
         self.atom_feats_to_pooled_token = AtomToTokenPooler(
@@ -1565,7+1596,8 @@ def __init__(
             dim_pairwise = dim_pairwise,
             depth = token_transformer_depth,
             heads = token_transformer_heads,
-            serial = serial
+            serial = serial,
+            **token_transformer_kwargs
         )
 
         self.attended_token_norm = nn.LayerNorm(dim_token)
@@ -1581,7 +1613,8 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_decoder_depth,
             heads = atom_decoder_heads,
-            serial = serial
+            serial = serial,
+            **atom_decoder_kwargs
        )
 
         self.atom_feat_to_atom_pos_update = nn.Sequential(
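Note: the remaining hunks are plumbing. DiffusionModule now accepts atom_encoder_kwargs, atom_decoder_kwargs and token_transformer_kwargs dicts and splats them into the three DiffusionTransformer constructors, which is how num_register_tokens reaches the token transformer without adding a new top-level parameter for every option. A minimal sketch of this forwarding pattern with stand-in classes (not the repo's actual signatures):

from torch import nn

class Inner(nn.Module):
    # stand-in for DiffusionTransformer
    def __init__(self, depth, num_register_tokens = 0):
        super().__init__()
        self.num_register_tokens = num_register_tokens

class Outer(nn.Module):
    # stand-in for DiffusionModule: forwards an opaque kwargs dict to its sub-module
    def __init__(self, token_transformer_depth = 2, token_transformer_kwargs: dict = dict()):
        super().__init__()
        self.token_transformer = Inner(
            depth = token_transformer_depth,
            **token_transformer_kwargs
        )

module = Outer(token_transformer_kwargs = dict(num_register_tokens = 2))
assert module.token_transformer.num_register_tokens == 2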

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.40"
+version = "0.0.41"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 4 additions & 1 deletion

@@ -233,7 +233,10 @@ def test_diffusion_module():
         dim_pairwise_rel_pos_feats = 12,
         atom_encoder_depth = 1,
         atom_decoder_depth = 1,
-        token_transformer_depth = 1
+        token_transformer_depth = 1,
+        token_transformer_kwargs = dict(
+            num_register_tokens = 2
+        )
     )
 
     atom_pos_update = diffusion_module(
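Note: the test only enables registers on the token transformer. Since atom_encoder_kwargs and atom_decoder_kwargs are splatted into the same DiffusionTransformer class, the same num_register_tokens option should presumably work for the atom encoder and decoder as well.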
