Skip to content

Commit 909fc01

Browse files
committed
Maintains the input chain ids in RFdiffusion output
The output was previously renumbering all of the chains, making comparisons to the input structures and handling of multi-chain inputs challenging. This commit maintains the input chain ids in the output.
1 parent b44206a commit 909fc01

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

rfdiffusion/contigs.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
self.inpaint,
8181
self.inpaint_hal,
8282
self.inpaint_rf,
83+
self.sampled_mask_length_bound,
8384
) = self.expand_sampled_mask()
8485
self.ref = self.inpaint + self.receptor
8586
self.hal = self.inpaint_hal + self.receptor_hal
@@ -241,6 +242,8 @@ def expand_sampled_mask(self):
241242
inpaint_chain_idx = -1
242243
receptor_chain_break = []
243244
inpaint_chain_break = []
245+
_receptor_mask_length_bound = []
246+
_inpaint_mask_length_bound = []
244247
for con in self.sampled_mask:
245248
if (
246249
all([i[0].isalpha() for i in con.split("/")[:-1]])
@@ -286,6 +289,7 @@ def expand_sampled_mask(self):
286289
receptor_chain_break.append(
287290
(receptor_idx - 1, 200)
288291
) # 200 aa chain break
292+
_receptor_mask_length_bound.append(len(receptor))
289293
else:
290294
inpaint_chain_idx += 1
291295
for subcon in con.split("/"):
@@ -320,6 +324,7 @@ def expand_sampled_mask(self):
320324
)
321325
inpaint_idx += int(subcon.split("-")[0])
322326
inpaint_chain_break.append((inpaint_idx - 1, 200))
327+
_inpaint_mask_length_bound.append(len(inpaint))
323328

324329
if self.topo is True or inpaint_hal == []:
325330
receptor_hal = [(i[0], i[1]) for i in receptor_hal]
@@ -335,14 +340,21 @@ def expand_sampled_mask(self):
335340
inpaint_rf[ch_break[0] :] += ch_break[1]
336341
for ch_break in receptor_chain_break[:-1]:
337342
receptor_rf[ch_break[0] :] += ch_break[1]
338-
343+
sampled_mask_length_bound = []
344+
sampled_mask_length_bound.extend(_inpaint_mask_length_bound)
345+
if _inpaint_mask_length_bound:
346+
inpaint_last_bound = _inpaint_mask_length_bound[-1]
347+
else:
348+
inpaint_last_bound = 0
349+
sampled_mask_length_bound.extend(map(lambda x: x + inpaint_last_bound, _receptor_mask_length_bound))
339350
return (
340351
receptor,
341352
receptor_hal,
342353
receptor_rf.tolist(),
343354
inpaint,
344355
inpaint_hal,
345356
inpaint_rf.tolist(),
357+
sampled_mask_length_bound
346358
)
347359

348360
def get_inpaint_seq_str(self, inpaint_s, ss=False):

rfdiffusion/inference/model_runners.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,16 @@ def sample_init(self, return_forward_trajectory=False):
309309
contig_map=self.contig_map
310310

311311
self.diffusion_mask = self.mask_str
312-
self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(L_mapped)]
313-
312+
length_bound = self.contig_map.sampled_mask_length_bound.copy()
313+
314+
first_res = 0
315+
self.chain_idx = []
316+
for last_res in length_bound:
317+
chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} - {"_"}
318+
assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}"
319+
self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res)
320+
first_res = last_res
321+
314322
####################################
315323
### Generate initial coordinates ###
316324
####################################

0 commit comments

Comments
 (0)