Skip to content

Commit 2474c7f

Browse files
authored
Update alphafold3.py (#271)
1 parent 7df7521 commit 2474c7f

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@
206206

207207
# constants
208208

209+
# NOTE: for some types of (e.g., AMD ROCm) GPUs, this represents
210+
# the maximum number of elements that can be processed simultaneously
211+
# by backpropagation for a given loss tensor
212+
MAX_ELEMENTS_FOR_BACKPROP = int(2e8)
213+
209214
LinearNoBias = partial(Linear, bias = False)
210215

211216
# helper functions
@@ -2891,6 +2896,17 @@ def forward(
28912896
bond_losses = F.mse_loss(denoised_cdist, normalized_cdist, reduction = 'none')
28922897
bond_losses = bond_losses * loss_weights
28932898

2899+
if atompair_mask.sum() > MAX_ELEMENTS_FOR_BACKPROP:
2900+
# randomly subset the atom pairs to supervise
2901+
2902+
flat_atompair_mask_indices = torch.arange(atompair_mask.numel(), device=self.device)[atompair_mask.view(-1)]
2903+
num_true_atompairs = flat_atompair_mask_indices.size(0)
2904+
2905+
num_atompairs_to_ignore = num_true_atompairs - MAX_ELEMENTS_FOR_BACKPROP
2906+
ignored_atompair_indices = flat_atompair_mask_indices[torch.randperm(num_true_atompairs)[:num_atompairs_to_ignore]]
2907+
2908+
atompair_mask.view(-1)[ignored_atompair_indices] = False
2909+
28942910
bond_loss = bond_losses[atompair_mask].mean()
28952911

28962912
total_loss = total_loss + bond_loss

0 commit comments

Comments
 (0)