|
14 | 14 | from rfdiffusion import util |
15 | 15 | from hydra.core.hydra_config import HydraConfig |
16 | 16 | import os |
| 17 | +import string |
17 | 18 |
|
18 | 19 | from rfdiffusion.model_input_logger import pickle_function_call |
19 | 20 | import sys |
@@ -144,13 +145,14 @@ def initialize(self, conf: DictConfig) -> None: |
144 | 145 | self.symmetry = None |
145 | 146 |
|
146 | 147 | self.allatom = ComputeAllAtomCoords().to(self.device) |
147 | | - |
| 148 | + |
148 | 149 | if self.inf_conf.input_pdb is None: |
149 | 150 | # set default pdb |
150 | 151 | script_dir=os.path.dirname(os.path.realpath(__file__)) |
151 | 152 | self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb') |
152 | 153 | self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) |
153 | 154 | self.chain_idx = None |
| 155 | + self.idx_pdb = None |
154 | 156 |
|
155 | 157 | ############################## |
156 | 158 | ### Handle Partial Noising ### |
@@ -330,8 +332,42 @@ def sample_init(self, return_forward_trajectory=False): |
330 | 332 | contig_map=self.contig_map |
331 | 333 |
|
332 | 334 | self.diffusion_mask = self.mask_str |
333 | | - self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(L_mapped)] |
334 | | - |
| 335 | + length_bound = self.contig_map.sampled_mask_length_bound.copy() |
| 336 | + |
| 337 | + first_res = 0 |
| 338 | + self.chain_idx = [] |
| 339 | + self.idx_pdb = [] |
| 340 | + all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref} |
| 341 | + available_chains = sorted(list(set(string.ascii_letters) - all_chains)) |
| 342 | + |
| 343 | + # Iterate over each chain |
| 344 | + for last_res in length_bound: |
| 345 | + chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} |
| 346 | + # If we are designing this chain, it will have a '_' in the contig map |
| 347 | + # Renumber this chain from 1 |
| 348 | + if "_" in chain_ids: |
| 349 | + self.idx_pdb += [idx + 1 for idx in range(last_res - first_res)] |
| 350 | + chain_ids = chain_ids - {"_"} |
| 351 | + # If there are no fixed residues that have a chain id, pick the first available letter |
| 352 | + if not chain_ids: |
| 353 | + if not available_chains: |
| 354 | + raise ValueError(f"No available chains! You are trying to design a new chain, and you have " |
| 355 | + f"already used all upper- and lower-case chain ids (up to 52 chains): " |
| 356 | + f"{','.join(all_chains)}.") |
| 357 | + chain_id = available_chains[0] |
| 358 | + available_chains.remove(chain_id) |
| 359 | + # Otherwise, use the chain of the fixed (motif) residues |
| 360 | + else: |
| 361 | + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" |
| 362 | + chain_id = list(chain_ids)[0] |
| 363 | + self.chain_idx += [chain_id] * (last_res - first_res) |
| 364 | + # If this is a fixed chain, maintain the chain and residue numbering |
| 365 | + else: |
| 366 | + self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]] |
| 367 | + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" |
| 368 | + self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) |
| 369 | + first_res = last_res |
| 370 | + |
335 | 371 | #################################### |
336 | 372 | ### Generate initial coordinates ### |
337 | 373 | #################################### |
|
0 commit comments