@@ -77,14 +77,14 @@ def initialize(self, conf: DictConfig) -> None:
                 if conf.contigmap.provide_seq is not None:
                     # this is only used for partial diffusion
                     assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion"
-                if conf.scaffoldguided.scaffoldguided:
+                if conf.scaffoldguided.scaffoldguided_enable:
                     self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt'
                 else:
                     self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt'
-            elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False:
+            elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided_enable is False:
                 # use complex trained model
                 self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt'
-            elif conf.scaffoldguided.scaffoldguided is True:
+            elif conf.scaffoldguided.scaffoldguided_enable is True:
                 # use complex and secondary structure-guided model
                 self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt'
             else:
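
Aside from the diff itself: a minimal, hedged sketch of the checkpoint-selection branches above, assuming an OmegaConf/Hydra-style config (implied by the DictConfig signature). The pick_checkpoint helper and the Base_ckpt.pt fallback in the final else branch are illustrative assumptions, not the repository's code.

# Hedged sketch (not the repository's code) of the checkpoint selection above,
# using the renamed `scaffoldguided_enable` key.
from omegaconf import OmegaConf

def pick_checkpoint(conf, model_directory):
    if conf.contigmap.provide_seq is not None:
        # partial-diffusion / provide_seq path
        name = 'InpaintSeq_Fold_ckpt.pt' if conf.scaffoldguided.scaffoldguided_enable else 'InpaintSeq_ckpt.pt'
    elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided_enable is False:
        name = 'Complex_base_ckpt.pt'         # complex-trained model
    elif conf.scaffoldguided.scaffoldguided_enable is True:
        name = 'Complex_Fold_base_ckpt.pt'    # complex + secondary structure-guided model
    else:
        name = 'Base_ckpt.pt'                 # assumed default fallback, not shown in the hunk
    return f'{model_directory}/{name}'

conf = OmegaConf.create({
    'contigmap': {'provide_seq': None},
    'ppi': {'hotspot_res': ['A30', 'A33']},
    'scaffoldguided': {'scaffoldguided_enable': False},
})
print(pick_checkpoint(conf, './models'))  # ./models/Complex_base_ckpt.pt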
@@ -279,7 +279,6 @@ def sample_init(self, return_forward_trajectory=False):
         self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:]
         self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:]
         self.binderlen = len(self.contig_map.inpaint)
-
         #######################################
         ### Resolve cyclic peptide indicies ###
         #######################################
@@ -301,7 +300,7 @@ def sample_init(self, return_forward_trajectory=False):
                 self.cyclic_reses = is_cyclized
         else:
             self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
-
+        
         ####################
         ### Get Hotspots ###
         ####################
@@ -681,7 +680,6 @@ def sample_step(self, *, t, x_t, seq_init, final_step):
         ####################
         ### Forward Pass ###
         ####################
-
         with torch.no_grad():
             msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked,
                                                                                     msa_full,
@@ -771,7 +769,7 @@ def __init__(self, conf: DictConfig):
         else:
             # initialize BlockAdjacency sampling class
             assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss"
-            self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs)
+            self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs)
 
 
         #################################################
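
The change above hands iu.BlockAdjacency the whole config rather than only the scaffoldguided sub-config. BlockAdjacency's internals are not shown here, but the pattern this enables is sketched below with a hypothetical class; the cyclic field is an assumption drawn from this PR's theme, and only conf.scaffoldguided.* and conf.inference.num_designs are confirmed by the diff.

# Hypothetical sketch of why passing the full DictConfig can be useful.
from omegaconf import DictConfig, OmegaConf

class BlockAdjacencySketch:
    """Stand-in, not the repository's class: with the full config, the sampler can
    read groups outside `scaffoldguided` (e.g. `inference.*`) without a signature change."""
    def __init__(self, conf: DictConfig, num_designs: int):
        self.scaffold_conf = conf.scaffoldguided           # same group it used before
        self.num_designs = num_designs
        self.cyclic = conf.inference.get('cyclic', False)  # assumed extra field for illustration

conf = OmegaConf.create({
    'scaffoldguided': {'scaffoldguided_enable': True},
    'inference': {'num_designs': 10, 'cyclic': True},
})
sampler = BlockAdjacencySketch(conf, conf.inference.num_designs)
print(sampler.num_designs, sampler.cyclic)  # 10 True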
@@ -945,6 +943,27 @@ def sample_init(self):
 
 
         xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:])
+
+        ################################
+        ### Add to Cyclic_reses init ###
+        ################################
+
+        if self._conf.inference.cyclic:
+            if self._conf.inference.cyc_chains is None:
+                self.cyclic_reses = ~self.mask_str.to(self.device).squeeze()
+            else:
+                assert isinstance(self._conf.inference.cyc_chains, str), 'cyc_chains arg must be string'
+                cyc_chains = self._conf.inference.cyc_chains
+                cyc_chains = [i.upper() for i in cyc_chains]
+                hal_idx = self.contig_map.hal
+                is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
+                for ch in cyc_chains:
+                    ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool()
+                    is_cyclized[ch_mask] = True
+                self.cyclic_reses = is_cyclized
+        else:
+            self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
+
         return xT, seq_T
 
     def _preprocess(self, seq, xyz_t, t):
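
To make the new cyc_chains handling above easier to follow, here is a hedged, standalone sketch of the same mask-building idea. It assumes contig_map.hal entries behave like (chain, residue_number) pairs, as implied by the idx[0] == ch comparison; cyclic_residue_mask is an illustrative helper, not part of the repository.

# Standalone sketch of the per-chain cyclization mask (illustrative only).
import torch

def cyclic_residue_mask(hal_idx, cyc_chains):
    """Mark residues whose chain letter appears in cyc_chains (case-insensitive)."""
    chains = {c.upper() for c in cyc_chains}
    mask = torch.zeros(len(hal_idx), dtype=torch.bool)
    for ch in chains:
        ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx])
        mask |= ch_mask
    return mask

hal_idx = [('A', 1), ('A', 2), ('B', 1), ('B', 2), ('B', 3)]
print(cyclic_residue_mask(hal_idx, 'b'))
# tensor([False, False,  True,  True,  True])

When cyc_chains is None, the added code instead falls back to ~self.mask_str, which appears to cyclize every residue outside the fixed-structure mask.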