Skip to content

Commit 63e270f

Browse files
committed
For fixed chains, retain residue numbering
For chains that are completely fixed, retain the residue numbering from the input rather than renumbering. For chains that are partially or fully designed by RFdiffusion, it isn't clear to me what the 'correct' behaviour should be, so these chains will be re-numbered starting at residue 1.
1 parent 909fc01 commit 63e270f

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

rfdiffusion/inference/model_runners.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from rfdiffusion import util
1515
from hydra.core.hydra_config import HydraConfig
1616
import os
17+
import string
1718

1819
from rfdiffusion.model_input_logger import pickle_function_call
1920
import sys
@@ -144,13 +145,14 @@ def initialize(self, conf: DictConfig) -> None:
144145
self.symmetry = None
145146

146147
self.allatom = ComputeAllAtomCoords().to(self.device)
147-
148+
148149
if self.inf_conf.input_pdb is None:
149150
# set default pdb
150151
script_dir=os.path.dirname(os.path.realpath(__file__))
151152
self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb')
152153
self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False)
153154
self.chain_idx = None
155+
self.idx_pdb = None
154156

155157
##############################
156158
### Handle Partial Noising ###
@@ -313,10 +315,31 @@ def sample_init(self, return_forward_trajectory=False):
313315

314316
first_res = 0
315317
self.chain_idx = []
318+
self.idx_pdb = []
319+
all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref}
320+
available_chains = sorted(list(set(string.ascii_uppercase) - all_chains))
321+
# Iterate over each chain
316322
for last_res in length_bound:
317-
chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} - {"_"}
318-
assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}"
319-
self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res)
323+
chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]}
324+
# If we are designing this chain, it will have a '-' in the contig map
325+
# Renumber this chain from 1
326+
if "_" in chain_ids:
327+
self.idx_pdb += [idx + 1 for idx in range(last_res - first_res)]
328+
chain_ids = chain_ids - {"_"}
329+
# If there are no fixed residues that have a chain id, pick the first available letter
330+
if not chain_ids:
331+
chain_id = available_chains[0]
332+
available_chains.remove(chain_id)
333+
# Otherwise, use the chain of the fixed (motif) residues
334+
else:
335+
assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}"
336+
chain_id = list(chain_ids)[0]
337+
self.chain_idx += [chain_id] * (last_res - first_res)
338+
# If this is a fixed chain, maintain the chain and residue numbering
339+
else:
340+
self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]]
341+
assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}"
342+
self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res)
320343
first_res = last_res
321344

322345
####################################

scripts/run_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def main(conf: HydraConfig) -> None:
141141
sampler.binderlen,
142142
chain_idx=sampler.chain_idx,
143143
bfacts=bfacts,
144+
idx_pdb=sampler.idx_pdb
144145
)
145146

146147
# run metadata

0 commit comments

Comments
 (0)