Skip to content

Commit 91dabdc

Browse files
authored
Make align_weights in alphafold3.py perfectly match Equation 4 in the AF3 supplement (#21)
1 parent f746f7b commit 91dabdc

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2021,8 +2021,9 @@ def forward(
20212021

20222022
# section 3.7.1 equation 4
20232023

2024-
align_weights = torch.where(atom_is_dna | atom_is_rna, nucleotide_loss_weight, align_weights)
2025-
align_weights = torch.where(atom_is_ligand, ligand_loss_weight, align_weights)
2024+
# upweighting of nucleotide and ligand atoms is additive per equation 4
2025+
align_weights = torch.where(atom_is_dna | atom_is_rna, 1 + nucleotide_loss_weight, align_weights)
2026+
align_weights = torch.where(atom_is_ligand, 1 + ligand_loss_weight, align_weights)
20262027

20272028
# section 3.7.1 equation 2 - weighted rigid aligned ground truth
20282029

0 commit comments

Comments
 (0)