Skip to content

Commit 7c30fee

Browse files
Fixed issues with designing in scaffoldguided mode, for example: design_ppi_scaffold.sh. The solutions to issues 272 and 273 did not fully address the issue.
1 parent fa34014 commit 7c30fee

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed

config/inference/base.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ logging:
126126
inputs: False
127127

128128
scaffoldguided:
129-
scaffoldguided: False
129+
scaffoldguided_enable: False
130130
target_pdb: False
131131
target_path: null
132132
scaffold_list: null

examples/design_ppi_scaffolded.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
# We then provide a path to a directory of different scaffolds (we've provided some for you to use, from Cao et al., 2022)
88
# We generate 10 designs, and reduce the noise added during inference to 0 (which improves the quality of designs)
99

10-
../scripts/run_inference.py scaffoldguided.target_path=input_pdbs/insulin_target.pdb inference.output_prefix=example_outputs/design_ppi_scaffolded scaffoldguided.scaffoldguided=True 'ppi.hotspot_res=[A59,A83,A91]' scaffoldguided.target_pdb=True scaffoldguided.target_ss=target_folds/insulin_target_ss.pt scaffoldguided.target_adj=target_folds/insulin_target_adj.pt scaffoldguided.scaffold_dir=./ppi_scaffolds/ inference.num_designs=10 denoiser.noise_scale_ca=0 denoiser.noise_scale_frame=0
10+
../scripts/run_inference.py scaffoldguided.target_path=input_pdbs/insulin_target.pdb inference.output_prefix=example_outputs/design_ppi_scaffolded scaffoldguided.scaffoldguided_enable=True 'ppi.hotspot_res=[A59,A83,A91]' scaffoldguided.target_pdb=True scaffoldguided.target_ss=target_folds/insulin_target_ss.pt scaffoldguided.target_adj=target_folds/insulin_target_adj.pt scaffoldguided.scaffold_dir=./ppi_scaffolds/ inference.num_designs=10 denoiser.noise_scale_ca=0 denoiser.noise_scale_frame=0

rfdiffusion/inference/model_runners.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ def initialize(self, conf: DictConfig) -> None:
7777
if conf.contigmap.provide_seq is not None:
7878
# this is only used for partial diffusion
7979
assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion"
80-
if conf.scaffoldguided.scaffoldguided:
80+
if conf.scaffoldguided.scaffoldguided_enable:
8181
self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt'
8282
else:
8383
self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt'
84-
elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False:
84+
elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided_enable is False:
8585
# use complex trained model
8686
self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt'
87-
elif conf.scaffoldguided.scaffoldguided is True:
87+
elif conf.scaffoldguided.scaffoldguided_enable is True:
8888
# use complex and secondary structure-guided model
8989
self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt'
9090
else:
@@ -279,7 +279,6 @@ def sample_init(self, return_forward_trajectory=False):
279279
self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:]
280280
self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:]
281281
self.binderlen = len(self.contig_map.inpaint)
282-
283282
#######################################
284283
### Resolve cyclic peptide indicies ###
285284
#######################################
@@ -301,7 +300,7 @@ def sample_init(self, return_forward_trajectory=False):
301300
self.cyclic_reses = is_cyclized
302301
else:
303302
self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
304-
303+
305304
####################
306305
### Get Hotspots ###
307306
####################
@@ -681,7 +680,6 @@ def sample_step(self, *, t, x_t, seq_init, final_step):
681680
####################
682681
### Forward Pass ###
683682
####################
684-
685683
with torch.no_grad():
686684
msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked,
687685
msa_full,
@@ -771,7 +769,7 @@ def __init__(self, conf: DictConfig):
771769
else:
772770
# initialize BlockAdjacency sampling class
773771
assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss"
774-
self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs)
772+
self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs)
775773

776774

777775
#################################################
@@ -945,6 +943,27 @@ def sample_init(self):
945943

946944

947945
xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:])
946+
947+
################################
948+
### Add to Cyclic_reses init ###
949+
################################
950+
951+
if self._conf.inference.cyclic:
952+
if self._conf.inference.cyc_chains is None:
953+
self.cyclic_reses = ~self.mask_str.to(self.device).squeeze()
954+
else:
955+
assert isinstance(self._conf.inference.cyc_chains, str), 'cyc_chains arg must be string'
956+
cyc_chains = self._conf.inference.cyc_chains
957+
cyc_chains = [i.upper() for i in cyc_chains]
958+
hal_idx = self.contig_map.hal
959+
is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
960+
for ch in cyc_chains:
961+
ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool()
962+
is_cyclized[ch_mask] = True
963+
self.cyclic_reses = is_cyclized
964+
else:
965+
self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
966+
948967
return xT, seq_T
949968

950969
def _preprocess(self, seq, xyz_t, t):

rfdiffusion/inference/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def get_next_pose(
502502

503503

504504
def sampler_selector(conf: DictConfig):
505-
if conf.scaffoldguided.scaffoldguided:
505+
if conf.scaffoldguided.scaffoldguided_enable:
506506
sampler = model_runners.ScaffoldedSampler(conf)
507507
else:
508508
if conf.inference.model_runner == "default":
@@ -1012,4 +1012,4 @@ def ss_from_contig(ss_masks: dict):
10121012
for idx, mask in enumerate([ss_masks['helix'],ss_masks['strand'], ss_masks['loop']]):
10131013
ss[mask,idx] = 1
10141014
ss[mask, 3] = 0 # remove the mask token
1015-
return ss
1015+
return ss

0 commit comments

Comments
 (0)