File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change 206206
207207# constants
208208
209+ # NOTE: for some types of (e.g., AMD ROCm) GPUs, this represents
210+ # the maximum number of elements that can be processed simultaneously
211+ # by backpropagation for a given loss tensor
212+ MAX_ELEMENTS_FOR_BACKPROP = int (2e8 )
213+
209214LinearNoBias = partial (Linear , bias = False )
210215
211216# helper functions
@@ -2891,6 +2896,17 @@ def forward(
28912896 bond_losses = F .mse_loss (denoised_cdist , normalized_cdist , reduction = 'none' )
28922897 bond_losses = bond_losses * loss_weights
28932898
2899+ if atompair_mask .sum () > MAX_ELEMENTS_FOR_BACKPROP :
2900+ # randomly subset the atom pairs to supervise
2901+
2902+ flat_atompair_mask_indices = torch .arange (atompair_mask .numel (), device = self .device )[atompair_mask .view (- 1 )]
2903+ num_true_atompairs = flat_atompair_mask_indices .size (0 )
2904+
2905+ num_atompairs_to_ignore = num_true_atompairs - MAX_ELEMENTS_FOR_BACKPROP
2906+ ignored_atompair_indices = flat_atompair_mask_indices [torch .randperm (num_true_atompairs )[:num_atompairs_to_ignore ]]
2907+
2908+ atompair_mask .view (- 1 )[ignored_atompair_indices ] = False
2909+
28942910 bond_loss = bond_losses [atompair_mask ].mean ()
28952911
28962912 total_loss = total_loss + bond_loss
You can’t perform that action at this time.
0 commit comments