Skip to content

Commit 2331ed5

Browse files
committed
cleanup
1 parent 9907a3d commit 2331ed5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4292,7 +4292,7 @@ class ScoreDetails(NamedTuple):
42924292
best_gpde_index: int
42934293
best_lddt_index: int
42944294
score: Float[' b']
4295-
scored_samples: ScoredSample
4295+
scored_samples: List[ScoredSample]
42964296

42974297
class ComputeModelSelectionScore(Module):
42984298
"""Compute model selection score."""
@@ -4894,11 +4894,11 @@ def compute_model_selection_score(
48944894

48954895
# rank by batch-averaged gPDE
48964896

4897-
best_gpde_index = torch.stack(all_gpde).mean(dim = -1).topk(1).indices.item()
4897+
best_gpde_index = torch.stack(all_gpde).mean(dim = -1).argmax().item()
48984898

48994899
# rank by batch-averaged lDDT
49004900

4901-
best_lddt_index = torch.stack(all_weighted_lddt).mean(dim = -1).topk(1).indices.item()
4901+
best_lddt_index = torch.stack(all_weighted_lddt).mean(dim = -1).argmax().item()
49024902

49034903
# some weighted score
49044904

0 commit comments

Comments
 (0)