Skip to content

Commit e220924

Browse files
authored
Retain chain and residue numbering in RFdiffusion (#348)
A number of issues (e.g., #103, #171, #312, #315) have mentioned that RFdiffusion changes the chain IDs and residue numbering of the input structure. The designed chain ends up as chain "A", and the fixed chain(s) end up as chain "B". The numbering is also reset to start at 1. This can be particularly problematic in cases where comparisons to other structures are needed, as well as in multi-chain situations, where all of the chains get fused.

Inspired by @GCS-ZHN's comment and solution referenced in Issue #103, I've modified the code to maintain chain and residue numbering. In particular:

- Chains that are not "designable" will retain their original chain ID letters and residue numbers.
- Chains that are partially fixed (e.g., motif re-scaffolding) will retain their original chain ID letters. Residues will be renumbered from 1 to the length of the chain. (It was not clear to me what the "correct" behaviour of chain residue numbering should be, given that the length of the chain and the position of any fixed residues might change.)
- Chains that are being fully generated de novo will be assigned the first available chain ID in the alphabet that is not used by any other chain. Residues will be numbered from 1 to the length of the chain.
2 parents 735de5e + d32205a commit e220924

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

rfdiffusion/contigs.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
self.inpaint,
8181
self.inpaint_hal,
8282
self.inpaint_rf,
83+
self.sampled_mask_length_bound,
8384
) = self.expand_sampled_mask()
8485
self.ref = self.inpaint + self.receptor
8586
self.hal = self.inpaint_hal + self.receptor_hal
@@ -241,6 +242,8 @@ def expand_sampled_mask(self):
241242
inpaint_chain_idx = -1
242243
receptor_chain_break = []
243244
inpaint_chain_break = []
245+
_receptor_mask_length_bound = []
246+
_inpaint_mask_length_bound = []
244247
for con in self.sampled_mask:
245248
if (
246249
all([i[0].isalpha() for i in con.split("/")[:-1]])
@@ -286,6 +289,7 @@ def expand_sampled_mask(self):
286289
receptor_chain_break.append(
287290
(receptor_idx - 1, 200)
288291
) # 200 aa chain break
292+
_receptor_mask_length_bound.append(len(receptor))
289293
else:
290294
inpaint_chain_idx += 1
291295
for subcon in con.split("/"):
@@ -320,6 +324,7 @@ def expand_sampled_mask(self):
320324
)
321325
inpaint_idx += int(subcon.split("-")[0])
322326
inpaint_chain_break.append((inpaint_idx - 1, 200))
327+
_inpaint_mask_length_bound.append(len(inpaint))
323328

324329
if self.topo is True or inpaint_hal == []:
325330
receptor_hal = [(i[0], i[1]) for i in receptor_hal]
@@ -335,14 +340,21 @@ def expand_sampled_mask(self):
335340
inpaint_rf[ch_break[0] :] += ch_break[1]
336341
for ch_break in receptor_chain_break[:-1]:
337342
receptor_rf[ch_break[0] :] += ch_break[1]
338-
343+
sampled_mask_length_bound = []
344+
sampled_mask_length_bound.extend(_inpaint_mask_length_bound)
345+
if _inpaint_mask_length_bound:
346+
inpaint_last_bound = _inpaint_mask_length_bound[-1]
347+
else:
348+
inpaint_last_bound = 0
349+
sampled_mask_length_bound.extend(map(lambda x: x + inpaint_last_bound, _receptor_mask_length_bound))
339350
return (
340351
receptor,
341352
receptor_hal,
342353
receptor_rf.tolist(),
343354
inpaint,
344355
inpaint_hal,
345356
inpaint_rf.tolist(),
357+
sampled_mask_length_bound
346358
)
347359

348360
def get_inpaint_seq_str(self, inpaint_s, ss=False):

rfdiffusion/inference/model_runners.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from rfdiffusion import util
1515
from hydra.core.hydra_config import HydraConfig
1616
import os
17+
import string
1718

1819
from rfdiffusion.model_input_logger import pickle_function_call
1920
import sys
@@ -144,13 +145,14 @@ def initialize(self, conf: DictConfig) -> None:
144145
self.symmetry = None
145146

146147
self.allatom = ComputeAllAtomCoords().to(self.device)
147-
148+
148149
if self.inf_conf.input_pdb is None:
149150
# set default pdb
150151
script_dir=os.path.dirname(os.path.realpath(__file__))
151152
self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb')
152153
self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False)
153154
self.chain_idx = None
155+
self.idx_pdb = None
154156

155157
##############################
156158
### Handle Partial Noising ###
@@ -331,8 +333,42 @@ def sample_init(self, return_forward_trajectory=False):
331333
contig_map=self.contig_map
332334

333335
self.diffusion_mask = self.mask_str
334-
self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(L_mapped)]
335-
336+
length_bound = self.contig_map.sampled_mask_length_bound.copy()
337+
338+
first_res = 0
339+
self.chain_idx = []
340+
self.idx_pdb = []
341+
all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref}
342+
available_chains = sorted(list(set(string.ascii_letters) - all_chains))
343+
344+
# Iterate over each chain
345+
for last_res in length_bound:
346+
chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]}
347+
# If we are designing this chain, it will have a '_' in the contig map
348+
# Renumber this chain from 1
349+
if "_" in chain_ids:
350+
self.idx_pdb += [idx + 1 for idx in range(last_res - first_res)]
351+
chain_ids = chain_ids - {"_"}
352+
# If there are no fixed residues that have a chain id, pick the first available letter
353+
if not chain_ids:
354+
if not available_chains:
355+
raise ValueError(f"No available chains! You are trying to design a new chain, and you have "
356+
f"already used all upper- and lower-case chain ids (up to 52 chains): "
357+
f"{','.join(all_chains)}.")
358+
chain_id = available_chains[0]
359+
available_chains.remove(chain_id)
360+
# Otherwise, use the chain of the fixed (motif) residues
361+
else:
362+
assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}"
363+
chain_id = list(chain_ids)[0]
364+
self.chain_idx += [chain_id] * (last_res - first_res)
365+
# If this is a fixed chain, maintain the chain and residue numbering
366+
else:
367+
self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]]
368+
assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}"
369+
self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res)
370+
first_res = last_res
371+
336372
####################################
337373
### Generate initial coordinates ###
338374
####################################

scripts/run_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def main(conf: HydraConfig) -> None:
141141
sampler.binderlen,
142142
chain_idx=sampler.chain_idx,
143143
bfacts=bfacts,
144+
idx_pdb=sampler.idx_pdb
144145
)
145146

146147
# run metadata

0 commit comments

Comments
 (0)