Skip to content

Commit b819fd5

Browse files
Joseph Watson, nrbennet
authored and committed
Added secondary structure specification
1 parent 820bfdf commit b819fd5

File tree

4 files changed

+89
-15
lines changed

4 files changed

+89
-15
lines changed

config/inference/base.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ contigmap:
2626
contigs: null
2727
inpaint_seq: null
2828
inpaint_str: null
29+
inpaint_str_helix: null
30+
inpaint_str_strand: null
31+
inpaint_str_loop: null
2932
provide_seq: null
3033
length: null
3134

rfdiffusion/contigs.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def __init__(
2727
inpaint_str_tensor=None,
2828
topo=False,
2929
provide_seq=None,
30+
inpaint_str_strand=None,
31+
inpaint_str_helix=None,
32+
inpaint_str_loop=None
3033
):
3134
# sanity checks
3235
if contigs is None and ref_idx is None:
@@ -48,12 +51,15 @@ def __init__(
4851
self.ref_idx = ref_idx
4952
self.hal_idx = hal_idx
5053
self.idx_rf = idx_rf
51-
self.inpaint_seq = (
52-
"/".join(inpaint_seq).split("/") if inpaint_seq is not None else None
53-
)
54-
self.inpaint_str = (
55-
"/".join(inpaint_str).split("/") if inpaint_str is not None else None
56-
)
54+
55+
parse_inpaint = lambda x: "/".join(x).split("/") if x is not None else None
56+
self.inpaint_seq = parse_inpaint(inpaint_seq)
57+
self.inpaint_str = parse_inpaint(inpaint_str)
58+
59+
self.inpaint_str_helix=parse_inpaint(inpaint_str_helix)
60+
self.inpaint_str_strand=parse_inpaint(inpaint_str_strand)
61+
self.inpaint_str_loop=parse_inpaint(inpaint_str_loop)
62+
5763
self.inpaint_seq_tensor = inpaint_seq_tensor
5864
self.inpaint_str_tensor = inpaint_str_tensor
5965
self.parsed_pdb = parsed_pdb
@@ -125,6 +131,39 @@ def __init__(
125131
else:
126132
self.inpaint_seq[int(i)] = True
127133

134+
"""
135+
We have now added the ability to specify the secondary structure of provided sequence.
136+
This is described in Liu et al., 2024
137+
https://www.biorxiv.org/content/10.1101/2024.07.16.603789v1
138+
This is for the case where, e.g., you have a sequence whose structure is unknown (like an IDR), but
139+
you still want to specify the secondary structure of this sequence.
140+
Making this compatible with the contigmap object allows all the variable length stuff to be handled.
141+
142+
The logic:
143+
Secondary structure is provided at the command line, using the following three flags:
144+
inpaint_str_helix
145+
inpaint_str_strand
146+
inpaint_str_loop
147+
148+
These are so named because they pertain to the region of the input pdb that you have applied inpaint_str to.
149+
In other words, any part of the input protein you are masking the structure of, you can specify the secondary structure of.
150+
However, you can't specify the secondary structure of a region you're not applying inpaint_str to, as this doesn't make sense.
151+
"""
152+
153+
if any(x is not None for x in (inpaint_str_helix, inpaint_str_strand, inpaint_str_loop)):
154+
self.ss_spec={}
155+
order=['helix','strand','loop']
156+
for idx, i in enumerate([inpaint_str_helix, inpaint_str_strand, inpaint_str_loop]):
157+
if i is not None:
158+
self.ss_spec[order[idx]] = ~self.get_inpaint_seq_str(i, ss=True)
159+
else:
160+
self.ss_spec[order[idx]] = np.zeros(len(self.inpaint_seq), dtype=bool)
161+
# some sensible checks
162+
for key, mask in self.ss_spec.items():
163+
assert sum(mask*self.inpaint_str) == 0, f"You've specified {key} residues that are not structure-masked with inpaint_str. This doesn't really make sense."
164+
stack=np.vstack([mask for mask in self.ss_spec.values()])
165+
assert np.max(np.sum(stack, axis=0)) == 1, "You've given multiple secondary structure assignations to an input residue. This doesn't make sense."
166+
128167
def get_sampled_mask(self):
129168
"""
130169
Function to get a sampled mask from a contig.
@@ -306,11 +345,14 @@ def expand_sampled_mask(self):
306345
inpaint_rf.tolist(),
307346
)
308347

309-
def get_inpaint_seq_str(self, inpaint_s):
310-
"""
348+
def get_inpaint_seq_str(self, inpaint_s, ss=False):
349+
'''
311350
function to generate inpaint_str or inpaint_seq masks specific to this contig
312-
"""
313-
s_mask = np.copy(self.mask_1d)
351+
'''
352+
if not ss:
353+
s_mask = np.copy(self.mask_1d)
354+
else:
355+
s_mask= np.ones(len(self.mask_1d), dtype=bool)
314356
inpaint_s_list = []
315357
for i in inpaint_s:
316358
if "-" in i:

rfdiffusion/inference/model_runners.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,15 @@ def __init__(self, conf: DictConfig):
741741
"""
742742
super().__init__(conf)
743743
# initialize BlockAdjacency sampling class
744-
self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs)
744+
if conf.scaffoldguided.scaffold_dir is None:
745+
assert any(x is not None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop))
746+
if conf.contigmap.inpaint_str_loop is not None:
747+
assert conf.scaffoldguided.mask_loops == False, "You shouldn't be masking loops if you're specifying loop secondary structure"
748+
else:
749+
# initialize BlockAdjacency sampling class
750+
assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss"
751+
self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs)
752+
745753

746754
#################################################
747755
### Initialize target, if doing binder design ###
@@ -773,8 +781,11 @@ def sample_init(self):
773781
##########################
774782
### Process Fold Input ###
775783
##########################
776-
self.L, self.ss, self.adj = self.blockadjacency.get_scaffold()
777-
self.adj = nn.one_hot(self.adj.long(), num_classes=3)
784+
if hasattr(self, 'blockadjacency'):
785+
self.L, self.ss, self.adj = self.blockadjacency.get_scaffold()
786+
self.adj = nn.one_hot(self.adj.long(), num_classes=3)
787+
else:
788+
self.L=100 # shim. Gets overwritten
778789

779790
##############################
780791
### Auto-contig generation ###
@@ -848,7 +859,7 @@ def sample_init(self):
848859
seq_T=torch.full((L_mapped,),21)
849860
seq_T[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0]
850861
seq_T[~self.mask_seq.squeeze()] = 21
851-
assert L_mapped==self.adj.shape[0]
862+
852863
diffusion_mask = self.mask_str
853864
self.diffusion_mask = diffusion_mask
854865

@@ -857,7 +868,13 @@ def sample_init(self):
857868
xT = get_init_xyz(xT).squeeze()
858869
atom_mask = torch.full((L_mapped, 27), False)
859870
atom_mask[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0]
860-
871+
872+
if hasattr(self.contig_map, 'ss_spec'):
873+
self.adj=torch.full((L_mapped, L_mapped),2) # masked
874+
self.adj=nn.one_hot(self.adj.long(), num_classes=3)
875+
self.ss=iu.ss_from_contig(self.contig_map.ss_spec)
876+
assert L_mapped==self.adj.shape[0]
877+
861878
####################
862879
### Get hotspots ###
863880
####################

rfdiffusion/inference/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,3 +1001,15 @@ def contig_crop(self, contig_crop, residue_offset=200) -> None:
10011001

10021002
def get_target(self):
10031003
return self.pdb
1004+
1005+
def ss_from_contig(ss_masks: dict):
    """
    Turn per-type 1D residue masks into a secondary structure input tensor.

    Args:
        ss_masks: dict with the keys 'helix', 'strand' and 'loop', each
            holding a 1D boolean mask over residues. The residue count is
            taken from the length of the 'helix' mask.

    Returns:
        A long tensor with one row per residue and four columns. Columns
        0, 1 and 2 are set to 1 for residues selected by the helix, strand
        and loop masks respectively; column 3 is the mask token and stays
        1 only for residues that no mask selects.
    """
    n_res = len(ss_masks['helix'])

    # Every residue starts out masked: only the mask-token column is set.
    ss = torch.zeros((n_res, 4), dtype=torch.long)
    ss[:, 3] = 1

    for channel, ss_type in enumerate(('helix', 'strand', 'loop')):
        selected = ss_masks[ss_type]
        ss[selected, channel] = 1
        # A residue with an assigned type no longer carries the mask token.
        ss[selected, 3] = 0
    return ss

0 commit comments

Comments
 (0)