@@ -2187,7 +2187,7 @@ def __init__(
21872187 self ,
21882188 * ,
21892189 dim_atom_inputs ,
2190- dim_additional_residue_feats ,
2190+ dim_additional_residue_feats = 10 ,
21912191 atoms_per_window = 27 ,
21922192 dim_atom = 128 ,
21932193 dim_atompair = 16 ,
@@ -2558,6 +2558,14 @@ def __init__(
25582558 ** relative_position_encoding_kwargs
25592559 )
25602560
2561+ # token bonds
2562+ # Algorithm 1 - line 5
2563+
2564+ self .token_bond_to_pairwise_feat = nn .Sequential (
2565+ Rearrange ('... -> ... 1' ),
2566+ LinearNoBias (1 , dim_pairwise )
2567+ )
2568+
25612569 # templates
25622570
25632571 self .template_embedder = TemplateEmbedder (
@@ -2654,7 +2662,8 @@ def forward(
26542662 atom_inputs : Float ['b m dai' ],
26552663 atom_mask : Bool ['b m' ],
26562664 atompair_feats : Float ['b m m dap' ],
2657- additional_residue_feats : Float ['b n rf' ],
2665+ additional_residue_feats : Float ['b n 10' ],
2666+ token_bond : Bool ['b n n' ] | None = None ,
26582667 msa : Float ['b s n d' ] | None = None ,
26592668 msa_mask : Bool ['b s' ] | None = None ,
26602669 templates : Float ['b t n n dt' ] | None = None ,
@@ -2673,7 +2682,13 @@ def forward(
26732682 return_loss_breakdown = False
26742683 ) -> Float ['b m 3' ] | Float ['' ] | Tuple [Float ['' ], LossBreakdown ]:
26752684
2685+ # get atom sequence length and residue sequence length
2686+
26762687 w = self .atoms_per_window
2688+ atom_seq_len = atom_inputs .shape [- 2 ]
2689+
2690+ assert divisible_by (atom_seq_len , w )
2691+ seq_len = atom_inputs .shape [- 2 ] // w
26772692
26782693 # embed inputs
26792694
@@ -2698,6 +2713,24 @@ def forward(
26982713
26992714 pairwise_init = pairwise_init + relative_position_encoding
27002715
2716+ # token bond features
2717+
2718+ if exists (token_bond ):
2719+ # we'll do some precautionary standardization
2720+ # (1) symmetrize, in case it is not already symmetrical (could also throw an error)
2721+ # (2) mask out diagonal - token to itself does not count as a bond
2722+
2723+ token_bond = token_bond | rearrange (token_bond , 'b i j -> b j i' )
2724+ diagonal = torch .eye (seq_len , device = self .device , dtype = torch .bool )
2725+ token_bond .masked_fill_ (diagonal , False )
2726+ else :
2727+ seq_arange = torch .arange (seq_len , device = self .device )
2728+ token_bond = einx .subtract ('i, j -> i j' , seq_arange , seq_arange ).abs () == 1
2729+
2730+ token_bond_feats = self .token_bond_to_pairwise_feat (token_bond .float ())
2731+
2732+ pairwise_init = pairwise_init + token_bond_feats
2733+
27012734 # pairwise mask
27022735
27032736 mask = reduce (atom_mask , 'b (n w) -> b n' , w = w , reduction = 'any' )
0 commit comments