@@ -262,6 +262,16 @@ def lens_to_mask(
262262 arange = torch .arange (max_len , device = device )
263263 return einx .less ('m, ... -> ... m' , arange , lens )
264264
265+ @typecheck
266+ def to_pairwise_mask (
267+ mask_i : Bool ['... n' ],
268+ mask_j : Bool ['... n' ] | None = None
269+ ) -> Bool ['... n n' ]:
270+
271+ mask_j = default (mask_j , mask_i )
272+ assert mask_i .shape == mask_j .shape
273+ return einx .logical_and ('... i, ... j -> ... i j' , mask_i , mask_j )
274+
265275@typecheck
266276def mean_pool_with_lens (
267277 feats : Float ['b m d' ],
@@ -598,7 +608,8 @@ def forward(
598608 ) -> Float ['b n n d' ]:
599609
600610 if exists (mask ):
601- mask = einx .logical_and ('b i, b j -> b i j 1' , mask , mask )
611+ mask = to_pairwise_mask (mask )
612+ mask = rearrange (mask , '... -> ... 1' )
602613
603614 x = self .norm (x )
604615
@@ -867,8 +878,8 @@ def forward(
867878 # masking for pairwise repr
868879
869880 if exists (mask ):
870- mask = einx . logical_and ( 'b i , b j -> b i j 1' , mask , mask )
871- outer_product_mean = outer_product_mean * mask
881+ mask = to_pairwise_mask ( mask )
882+ outer_product_mean = einx . multiply ( 'b i j d, b i j' , outer_product_mean , mask . float ())
872883
873884 pairwise_repr = self .to_pairwise_repr (outer_product_mean )
874885 return pairwise_repr
@@ -2338,7 +2349,7 @@ def forward(
23382349 bond_loss = self .zero
23392350
23402351 if add_bond_loss :
2341- atompair_mask = einx . logical_and ( 'b i, b j -> b i j' , atom_mask , atom_mask )
2352+ atompair_mask = to_pairwise_mask ( atom_mask )
23422353
23432354 denoised_cdist = torch .cdist (denoised_atom_pos , denoised_atom_pos , p = 2 )
23442355 normalized_cdist = torch .cdist (atom_pos_ground_truth , atom_pos_ground_truth , p = 2 )
@@ -2420,7 +2431,7 @@ def forward(
24202431
24212432 # Restrict to bespoke inclusion radius
24222433 is_nucleotide = is_dna | is_rna
2423- is_nucleotide_pair = einx . logical_and ( 'b i, b j -> b i j' , is_nucleotide , is_nucleotide )
2434+ is_nucleotide_pair = to_pairwise_mask ( is_nucleotide )
24242435
24252436 inclusion_radius = torch .where (
24262437 is_nucleotide_pair ,
@@ -2433,7 +2444,7 @@ def forward(
24332444
24342445 # Take into account variable lengthed atoms in batch
24352446 if exists (coords_mask ):
2436- paired_coords_mask = einx . logical_and ( 'b i, b j -> b i j' , coords_mask , coords_mask )
2447+ paired_coords_mask = to_pairwise_mask ( coords_mask )
24372448 mask = mask & paired_coords_mask
24382449
24392450 # Calculate masked averaging
@@ -3236,8 +3247,7 @@ def compute_pde(
32363247
32373248 pde = einsum (probs , bin_centers , 'b i j pde, pde -> b i j ' )
32383249
3239- mask = einx .logical_and (
3240- 'b i, b j -> b i j' , tok_repr_atm_mask , tok_repr_atm_mask )
3250+ mask = to_pairwise_mask (tok_repr_atm_mask )
32413251
32423252 pde = pde * mask
32433253 return pde
@@ -3610,8 +3620,7 @@ def compute_gpde(
36103620 contact_mask , dist_probs , 0.
36113621 ).sum (dim = - 1 )
36123622
3613- mask = einx .logical_and (
3614- 'b i, b j -> b i j' , tok_repr_atm_mask , tok_repr_atm_mask )
3623+ mask = to_pairwise_mask (tok_repr_atm_mask )
36153624 contact_prob = contact_prob * mask
36163625
36173626 # Section 5.7 equation 16
@@ -3652,7 +3661,7 @@ def compute_lddt(
36523661
36533662 # Restrict to bespoke inclusion radius
36543663 is_nucleotide = is_dna | is_rna
3655- is_nucleotide_pair = einx . logical_and ( 'b i, b j -> b i j' , is_nucleotide , is_nucleotide )
3664+ is_nucleotide_pair = to_pairwise_mask ( is_nucleotide )
36563665
36573666 inclusion_radius = torch .where (
36583667 is_nucleotide_pair ,
@@ -3665,7 +3674,7 @@ def compute_lddt(
36653674
36663675 # Take into account variable lengthed atoms in batch
36673676 if exists (coords_mask ):
3668- paired_coords_mask = einx . logical_and ( 'b i, b j -> b i j' , coords_mask , coords_mask )
3677+ paired_coords_mask = to_pairwise_mask ( coords_mask )
36693678 mask = mask & paired_coords_mask
36703679
36713680 mask = mask * pairwise_mask
@@ -3703,9 +3712,7 @@ def compute_chain_pair_lddt(
37033712
37043713 is_dna = is_molecule_types [..., IS_DNA_INDEX ]
37053714 is_rna = is_molecule_types [..., IS_RNA_INDEX ]
3706- pairwise_mask = einx .logical_and (
3707- 'b m, b n -> b m n' , asym_mask_a , asym_mask_b ,
3708- )
3715+ pairwise_mask = to_pairwise_mask (asym_mask_a )
37093716
37103717 lddt = self .compute_lddt (
37113718 pred_coords , true_coords , is_dna , is_rna , pairwise_mask , coords_mask
@@ -4342,7 +4349,7 @@ def forward(
43424349 # not to ligands + metal ions
43434350
43444351 is_chained_biomol = is_molecule_types [..., IS_BIOMOLECULE_INDICES ].any (dim = - 1 ) # first three types are chained biomolecules (protein, rna, dna)
4345- paired_is_chained_biomol = einx . logical_and ( 'b i, b j -> b i j' , is_chained_biomol , is_chained_biomol )
4352+ paired_is_chained_biomol = to_pairwise_mask ( is_chained_biomol )
43464353
43474354 relative_position_encoding = einx .where (
43484355 'b i j, b i j d, -> b i j d' ,
@@ -4374,7 +4381,7 @@ def forward(
43744381 # molecule mask and pairwise mask
43754382
43764383 mask = molecule_atom_lens > 0
4377- pairwise_mask = einx . logical_and ( 'b i, b j -> b i j' , mask , mask )
4384+ pairwise_mask = to_pairwise_mask ( mask )
43784385
43794386 # prepare mask for msa module and template embedder
43804387 # which is equivalent to the `is_protein` of the `is_molecular_types` input
@@ -4525,7 +4532,7 @@ def forward(
45254532
45264533 # account for representative distogram atom missing from residue (-1 set on distogram_atom_indices field)
45274534
4528- valid_distogram_mask = einx . logical_and ( 'b i, b j -> b i j' , valid_distogram_mask , valid_distogram_mask )
4535+ valid_distogram_mask = to_pairwise_mask ( valid_distogram_mask )
45294536 distance_labels .masked_fill_ (~ valid_distogram_mask , ignore )
45304537
45314538 if exists (distance_labels ):
@@ -4680,7 +4687,7 @@ def forward(
46804687 if self .confidence_head .atom_resolution :
46814688 label_mask = atom_mask
46824689
4683- label_pairwise_mask = einx . logical_and ( '... i, ... j -> ... i j' , label_mask , label_mask )
4690+ label_pairwise_mask = to_pairwise_mask ( label_mask )
46844691
46854692 # cross entropy losses
46864693
0 commit comments