|
14 | 14 | from rfdiffusion import util |
15 | 15 | from hydra.core.hydra_config import HydraConfig |
16 | 16 | import os |
| 17 | +import string |
17 | 18 |
|
18 | 19 | from rfdiffusion.model_input_logger import pickle_function_call |
19 | 20 | import sys |
@@ -144,13 +145,14 @@ def initialize(self, conf: DictConfig) -> None: |
144 | 145 | self.symmetry = None |
145 | 146 |
|
146 | 147 | self.allatom = ComputeAllAtomCoords().to(self.device) |
147 | | - |
| 148 | + |
148 | 149 | if self.inf_conf.input_pdb is None: |
149 | 150 | # set default pdb |
150 | 151 | script_dir=os.path.dirname(os.path.realpath(__file__)) |
151 | 152 | self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb') |
152 | 153 | self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) |
153 | 154 | self.chain_idx = None |
| 155 | + self.idx_pdb = None |
154 | 156 |
|
155 | 157 | ############################## |
156 | 158 | ### Handle Partial Noising ### |
@@ -313,10 +315,31 @@ def sample_init(self, return_forward_trajectory=False): |
313 | 315 |
|
314 | 316 | first_res = 0 |
315 | 317 | self.chain_idx = [] |
| 318 | + self.idx_pdb = [] |
| 319 | + all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref} |
| 320 | + available_chains = sorted(list(set(string.ascii_uppercase) - all_chains)) |
| 321 | + # Iterate over each chain |
316 | 322 | for last_res in length_bound: |
317 | | - chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} - {"_"} |
318 | | - assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" |
319 | | - self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) |
| 323 | + chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} |
| 324 | + # If we are designing this chain, it will have a '-' in the contig map |
| 325 | + # Renumber this chain from 1 |
| 326 | + if "_" in chain_ids: |
| 327 | + self.idx_pdb += [idx + 1 for idx in range(last_res - first_res)] |
| 328 | + chain_ids = chain_ids - {"_"} |
| 329 | + # If there are no fixed residues that have a chain id, pick the first available letter |
| 330 | + if not chain_ids: |
| 331 | + chain_id = available_chains[0] |
| 332 | + available_chains.remove(chain_id) |
| 333 | + # Otherwise, use the chain of the fixed (motif) residues |
| 334 | + else: |
| 335 | + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" |
| 336 | + chain_id = list(chain_ids)[0] |
| 337 | + self.chain_idx += [chain_id] * (last_res - first_res) |
| 338 | + # If this is a fixed chain, maintain the chain and residue numbering |
| 339 | + else: |
| 340 | + self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]] |
| 341 | + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" |
| 342 | + self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) |
320 | 343 | first_res = last_res |
321 | 344 |
|
322 | 345 | #################################### |
|
0 commit comments