@@ -1323,6 +1323,7 @@ def __init__(
         dim_pairwise = 128,
         attn_window_size = None,
         attn_pair_bias_kwargs: dict = dict(),
+        num_register_tokens = 0,
         serial = False
     ):
         super().__init__()
@@ -1365,6 +1366,12 @@ def __init__(
 
         self.serial = serial
 
+        self.has_registers = num_register_tokens > 0
+        self.num_registers = num_register_tokens
+
+        if self.has_registers:
+            self.registers = nn.Parameter(torch.zeros(num_register_tokens, dim))
+
     @typecheck
     def forward(
         self,
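The constructor hunks above add an optional bank of learned register tokens, presumably in the spirit of Darcet et al.'s "Vision Transformers Need Registers": extra slots the attention layers can use as global scratch space. The bank is zero-initialized, so training starts from a no-op. A tiny sketch of what the new state amounts to, with dim and num_register_tokens made up here purely for illustration:

import torch
from torch import nn
from einops import repeat

# one learned vector per register; zeros at init, updated by the optimizer
registers = nn.Parameter(torch.zeros(4, 384))

# at forward time each batch element gets its own copy (see the next hunk)
batched = repeat(registers, 'r d -> b r d', b = 2)
print(batched.shape)   # torch.Size([2, 4, 384])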
@@ -1376,6 +1383,21 @@ def forward(
     ):
         serial = self.serial
 
+        # register tokens
+
+        if self.has_registers:
+            num_registers = self.num_registers
+            registers = repeat(self.registers, 'r d -> b r d', b = noised_repr.shape[0])
+            noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')
+
+            single_repr = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
+            pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
+
+            if exists(mask):
+                mask = F.pad(mask, (num_registers, 0), value = True)
+
+        # main transformer
+
         for attn, transition in self.layers:
 
             attn_out = attn(
@@ -1398,6 +1420,11 @@ def forward(
 
             noised_repr = noised_repr + ff_out
 
+        # splice out registers
+
+        if self.has_registers:
+            _, noised_repr = unpack(noised_repr, registers_ps, 'b * d')
+
         return noised_repr
 
 class AtomToTokenPooler(Module):
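The two forward-pass hunks are a matched pair: einops' pack prepends the batched registers to the token stream and records the split in registers_ps, F.pad grows the single, pairwise, and mask streams at the front so everything agrees on the new sequence length, and after the attention layers unpack splices the registers back out. A self-contained sketch of that round trip, with the shapes (b, n, r, d and the pairwise dim) made up for illustration:

import torch
import torch.nn.functional as F
from einops import pack, unpack, repeat

b, n, r, d = 2, 16, 4, 64                       # batch, tokens, registers, dim

noised_repr   = torch.randn(b, n, d)
single_repr   = torch.randn(b, n, d)
pairwise_repr = torch.randn(b, n, n, 32)
mask          = torch.ones(b, n, dtype = torch.bool)
registers     = torch.zeros(r, d)               # stands in for self.registers

# prepend a per-batch copy of the registers, remembering the split
registers = repeat(registers, 'r d -> b r d', b = b)
noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')

# pad the conditioning streams at the front to match the new length
single_repr   = F.pad(single_repr, (0, 0, r, 0), value = 0.)
pairwise_repr = F.pad(pairwise_repr, (0, 0, r, 0, r, 0), value = 0.)
mask          = F.pad(mask, (r, 0), value = True)

assert noised_repr.shape == (b, r + n, d)
assert pairwise_repr.shape == (b, r + n, r + n, 32)

# ... attention / transition layers would run here on the padded streams ...

# splice the registers back out; only the real tokens are returned
_, noised_repr = unpack(noised_repr, registers_ps, 'b * d')
assert noised_repr.shape == (b, n, d)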
@@ -1487,7 +1514,10 @@ def __init__(
         token_transformer_heads = 16,
         atom_decoder_depth = 3,
         atom_decoder_heads = 4,
-        serial = False
+        serial = False,
+        atom_encoder_kwargs: dict = dict(),
+        atom_decoder_kwargs: dict = dict(),
+        token_transformer_kwargs: dict = dict()
     ):
         super().__init__()
@@ -1543,7 +1573,8 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_encoder_depth,
             heads = atom_encoder_heads,
-            serial = serial
+            serial = serial,
+            **atom_encoder_kwargs
         )
 
         self.atom_feats_to_pooled_token = AtomToTokenPooler(
@@ -1565,7 +1596,8 @@ def __init__(
             dim_pairwise = dim_pairwise,
             depth = token_transformer_depth,
             heads = token_transformer_heads,
-            serial = serial
+            serial = serial,
+            **token_transformer_kwargs
         )
 
         self.attended_token_norm = nn.LayerNorm(dim_token)
@@ -1581,7 +1613,8 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_decoder_depth,
             heads = atom_decoder_heads,
-            serial = serial
+            serial = serial,
+            **atom_decoder_kwargs
         )
 
         self.atom_feat_to_atom_pos_update = nn.Sequential(
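The last four hunks are plumbing: the enclosing wrapper (its class name falls outside these hunks; in context it is the module that owns the atom encoder, token transformer, and atom decoder) gains one passthrough kwargs dict per sub-network, forwarded with ** into each DiffusionTransformer. This is what lets a caller enable the new register tokens per sub-network. A hedged usage sketch, where the DiffusionModule name and every argument value are assumptions for illustration; only the three *_kwargs passthroughs come from the diff:

# hypothetical usage of the new constructor arguments
diffusion_module = DiffusionModule(
    atom_encoder_depth = 3,
    token_transformer_depth = 24,
    atom_decoder_depth = 3,
    token_transformer_kwargs = dict(
        num_register_tokens = 4    # forwarded into the token DiffusionTransformer
    ),
    atom_encoder_kwargs = dict(),  # the atom stages accept the same passthrough
    atom_decoder_kwargs = dict()
)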