Skip to content

Commit 6cd31c7

Browse files
authored
Update alphafold3.py (#219)
1 parent 7532e79 commit 6cd31c7

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3283,7 +3283,7 @@ def get_least_asym_entity_or_longest_length(
32833283

32843284
# Calculate entity length
32853285
entity_mask = batch["entity_id"] == entity_id
3286-
entity_length[int(entity_id)] = entity_mask.sum().item()
3286+
entity_length[int(entity_id)] = entity_mask.sum(-1).mode().values.item()
32873287

32883288
min_asym_count = min(entity_asym_count.values())
32893289
least_asym_entities = [
@@ -3313,6 +3313,18 @@ def get_least_asym_entity_or_longest_length(
33133313
if asym_id in input_asym_id
33143314
]
33153315

3316+
# Since the entity ID to asym ID mapping is many-to-many, we need to select only
3317+
# prediction asym IDs with equal length w.r.t. the sampled ground truth asym ID
3318+
anchor_gt_asym_id_length = (
3319+
(batch["asym_id"] == anchor_gt_asym_id).sum(-1).mode().values.item()
3320+
)
3321+
anchor_pred_asym_ids = [
3322+
asym_id
3323+
for asym_id in anchor_pred_asym_ids
3324+
if (batch["asym_id"] == asym_id).sum(-1).mode().values.item()
3325+
== anchor_gt_asym_id_length
3326+
]
3327+
33163328
# Remap `asym_id` values to remove any gaps in the ground truth asym IDs,
33173329
# but leave the prediction asym IDs as is since they are used for masking
33183330
sorted_asym_ids = sorted(all_asym_ids)

0 commit comments

Comments
 (0)