Skip to content

Commit 1b51b32

Browse files
committed
move a common operation for constituting pairwise masks from single mask into fn
1 parent 5d0aae6 commit 1b51b32

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,16 @@ def lens_to_mask(
262262
arange = torch.arange(max_len, device = device)
263263
return einx.less('m, ... -> ... m', arange, lens)
264264

265+
@typecheck
266+
def to_pairwise_mask(
267+
mask_i: Bool['... n'],
268+
mask_j: Bool['... n'] | None = None
269+
) -> Bool['... n n']:
270+
271+
mask_j = default(mask_j, mask_i)
272+
assert mask_i.shape == mask_j.shape
273+
return einx.logical_and('... i, ... j -> ... i j', mask_i, mask_j)
274+
265275
@typecheck
266276
def mean_pool_with_lens(
267277
feats: Float['b m d'],
@@ -598,7 +608,8 @@ def forward(
598608
) -> Float['b n n d']:
599609

600610
if exists(mask):
601-
mask = einx.logical_and('b i, b j -> b i j 1', mask, mask)
611+
mask = to_pairwise_mask(mask)
612+
mask = rearrange(mask, '... -> ... 1')
602613

603614
x = self.norm(x)
604615

@@ -867,8 +878,8 @@ def forward(
867878
# masking for pairwise repr
868879

869880
if exists(mask):
870-
mask = einx.logical_and('b i , b j -> b i j 1', mask, mask)
871-
outer_product_mean = outer_product_mean * mask
881+
mask = to_pairwise_mask(mask)
882+
outer_product_mean = einx.multiply('b i j d, b i j', outer_product_mean, mask.float())
872883

873884
pairwise_repr = self.to_pairwise_repr(outer_product_mean)
874885
return pairwise_repr
@@ -2338,7 +2349,7 @@ def forward(
23382349
bond_loss = self.zero
23392350

23402351
if add_bond_loss:
2341-
atompair_mask = einx.logical_and('b i, b j -> b i j', atom_mask, atom_mask)
2352+
atompair_mask = to_pairwise_mask(atom_mask)
23422353

23432354
denoised_cdist = torch.cdist(denoised_atom_pos, denoised_atom_pos, p = 2)
23442355
normalized_cdist = torch.cdist(atom_pos_ground_truth, atom_pos_ground_truth, p = 2)
@@ -2420,7 +2431,7 @@ def forward(
24202431

24212432
# Restrict to bespoke inclusion radius
24222433
is_nucleotide = is_dna | is_rna
2423-
is_nucleotide_pair = einx.logical_and('b i, b j -> b i j', is_nucleotide, is_nucleotide)
2434+
is_nucleotide_pair = to_pairwise_mask(is_nucleotide)
24242435

24252436
inclusion_radius = torch.where(
24262437
is_nucleotide_pair,
@@ -2433,7 +2444,7 @@ def forward(
24332444

24342445
# Take into account variable lengthed atoms in batch
24352446
if exists(coords_mask):
2436-
paired_coords_mask = einx.logical_and('b i, b j -> b i j', coords_mask, coords_mask)
2447+
paired_coords_mask = to_pairwise_mask(coords_mask)
24372448
mask = mask & paired_coords_mask
24382449

24392450
# Calculate masked averaging
@@ -3236,8 +3247,7 @@ def compute_pde(
32363247

32373248
pde = einsum(probs, bin_centers, 'b i j pde, pde -> b i j ')
32383249

3239-
mask = einx.logical_and(
3240-
'b i, b j -> b i j', tok_repr_atm_mask, tok_repr_atm_mask)
3250+
mask = to_pairwise_mask(tok_repr_atm_mask)
32413251

32423252
pde = pde * mask
32433253
return pde
@@ -3610,8 +3620,7 @@ def compute_gpde(
36103620
contact_mask, dist_probs, 0.
36113621
).sum(dim=-1)
36123622

3613-
mask = einx.logical_and(
3614-
'b i, b j -> b i j', tok_repr_atm_mask, tok_repr_atm_mask)
3623+
mask = to_pairwise_mask(tok_repr_atm_mask)
36153624
contact_prob = contact_prob * mask
36163625

36173626
# Section 5.7 equation 16
@@ -3652,7 +3661,7 @@ def compute_lddt(
36523661

36533662
# Restrict to bespoke inclusion radius
36543663
is_nucleotide = is_dna | is_rna
3655-
is_nucleotide_pair = einx.logical_and('b i, b j -> b i j', is_nucleotide, is_nucleotide)
3664+
is_nucleotide_pair = to_pairwise_mask(is_nucleotide)
36563665

36573666
inclusion_radius = torch.where(
36583667
is_nucleotide_pair,
@@ -3665,7 +3674,7 @@ def compute_lddt(
36653674

36663675
# Take into account variable lengthed atoms in batch
36673676
if exists(coords_mask):
3668-
paired_coords_mask = einx.logical_and('b i, b j -> b i j', coords_mask, coords_mask)
3677+
paired_coords_mask = to_pairwise_mask(coords_mask)
36693678
mask = mask & paired_coords_mask
36703679

36713680
mask = mask * pairwise_mask
@@ -3703,9 +3712,7 @@ def compute_chain_pair_lddt(
37033712

37043713
is_dna = is_molecule_types[..., IS_DNA_INDEX]
37053714
is_rna = is_molecule_types[..., IS_RNA_INDEX]
3706-
pairwise_mask = einx.logical_and(
3707-
'b m, b n -> b m n', asym_mask_a, asym_mask_b,
3708-
)
3715+
pairwise_mask = to_pairwise_mask(asym_mask_a)
37093716

37103717
lddt = self.compute_lddt(
37113718
pred_coords, true_coords, is_dna, is_rna, pairwise_mask, coords_mask
@@ -4342,7 +4349,7 @@ def forward(
43424349
# not to ligands + metal ions
43434350

43444351
is_chained_biomol = is_molecule_types[..., IS_BIOMOLECULE_INDICES].any(dim = -1) # first three types are chained biomolecules (protein, rna, dna)
4345-
paired_is_chained_biomol = einx.logical_and('b i, b j -> b i j', is_chained_biomol, is_chained_biomol)
4352+
paired_is_chained_biomol = to_pairwise_mask(is_chained_biomol)
43464353

43474354
relative_position_encoding = einx.where(
43484355
'b i j, b i j d, -> b i j d',
@@ -4374,7 +4381,7 @@ def forward(
43744381
# molecule mask and pairwise mask
43754382

43764383
mask = molecule_atom_lens > 0
4377-
pairwise_mask = einx.logical_and('b i, b j -> b i j', mask, mask)
4384+
pairwise_mask = to_pairwise_mask(mask)
43784385

43794386
# prepare mask for msa module and template embedder
43804387
# which is equivalent to the `is_protein` of the `is_molecular_types` input
@@ -4525,7 +4532,7 @@ def forward(
45254532

45264533
# account for representative distogram atom missing from residue (-1 set on distogram_atom_indices field)
45274534

4528-
valid_distogram_mask = einx.logical_and('b i, b j -> b i j', valid_distogram_mask, valid_distogram_mask)
4535+
valid_distogram_mask = to_pairwise_mask(valid_distogram_mask)
45294536
distance_labels.masked_fill_(~valid_distogram_mask, ignore)
45304537

45314538
if exists(distance_labels):
@@ -4680,7 +4687,7 @@ def forward(
46804687
if self.confidence_head.atom_resolution:
46814688
label_mask = atom_mask
46824689

4683-
label_pairwise_mask = einx.logical_and('... i, ... j -> ... i j', label_mask, label_mask)
4690+
label_pairwise_mask = to_pairwise_mask(label_mask)
46844691

46854692
# cross entropy losses
46864693

0 commit comments

Comments
 (0)