@@ -1847,7 +1847,8 @@ def forward(
18471847 atom_pos_aligned_ground_truth = self .weighted_rigid_align (
18481848 atom_pos_ground_truth ,
18491849 denoised_atom_pos ,
1850- align_weights
1850+ align_weights ,
1851+ mask = atom_mask
18511852 )
18521853
18531854 # main diffusion mse loss
@@ -1932,10 +1933,10 @@ def forward(
19321933 coords_mask : Bool ['b n' ] | None = None ,
19331934 ) -> Float ['' ]:
19341935 """
1935- pred_coords: predicted coordinates (b, n, 3)
1936- true_coords: true coordinates (b, n, 3)
1937- is_dna: boolean tensor indicating DNA atoms (b, n)
1938- is_rna: boolean tensor indicating RNA atoms (b, n)
1936+ pred_coords: predicted coordinates
1937+ true_coords: true coordinates
1938+ is_dna: boolean tensor indicating DNA atoms
1939+ is_rna: boolean tensor indicating RNA atoms
19391940 """
19401941 # Compute distances between all pairs of atoms
19411942 pred_dists = torch .cdist (pred_coords , pred_coords )
@@ -1954,15 +1955,16 @@ def forward(
19541955
19551956 # Restrict to bespoke inclusion radius
19561957 is_nucleotide = is_dna | is_rna
1957- is_nucleotide_pair = is_nucleotide .unsqueeze (- 1 ) & is_nucleotide .unsqueeze (- 2 )
1958+ is_nucleotide_pair = einx .logical_and ('b i, b j -> b i j' , is_nucleotide , is_nucleotide )
1959+
19581960 inclusion_radius = torch .where (
19591961 is_nucleotide_pair ,
19601962 true_dists < self .nucleic_acid_cutoff ,
19611963 true_dists < self .other_cutoff
19621964 )
19631965
19641966 # Compute mean, avoiding self term
1965- mask = torch . logical_and ( inclusion_radius , torch .logical_not ( torch . eye (pred_coords .shape [1 ], dtype = torch .bool , device = pred_coords .device )) )
1967+ mask = inclusion_radius & ~ torch .eye (pred_coords .shape [1 ], dtype = torch .bool , device = pred_coords .device )
19661968
19671969 # Take into account variable lengthed atoms in batch
19681970 if exists (coords_mask ):
@@ -1974,59 +1976,68 @@ def forward(
19741976 lddt_count = mask .sum (dim = (- 1 , - 2 ))
19751977 lddt = lddt_sum / lddt_count .clamp (min = 1 )
19761978
1977- return 1 - lddt .mean ()
1979+ return 1. - lddt .mean ()
19781980
19791981class WeightedRigidAlign (Module ):
19801982 """ Algorithm 28 """
1981- def __init__ (self ):
1982- super ().__init__ ()
19831983
19841984 @typecheck
19851985 def forward (
19861986 self ,
1987- pred_coords : Float ['b n 3' ],
1988- true_coords : Float ['b n 3' ],
1989- weights : Float ['b n' ]
1987+ pred_coords : Float ['b n 3' ], # predicted coordinates
1988+ true_coords : Float ['b n 3' ], # true coordinates
1989+ weights : Float ['b n' ], # weights for each atom
1990+ mask : Bool ['b n' ] | None = None # mask for variable lengths
19901991 ) -> Float ['b n 3' ]:
1991- """
1992- pred_coords: predicted coordinates (b, n, 3)
1993- true_coords: true coordinates (b, n, 3)
1994- weights: weights for each atom (b, n)
1995- """
1992+
1993+ if exists (mask ):
1994+ # zero out all predicted and true coordinates where not an atom
1995+ pred_coords = einx .where ('b n, b n c, -> b n c' , mask , pred_coords , 0. )
1996+ true_coords = einx .where ('b n, b n c, -> b n c' , mask , true_coords , 0. )
1997+ weights = einx .where ('b n, b n, -> b n' , mask , weights , 0. )
1998+
1999+ # Take care of weights broadcasting for coordinate dimension
2000+ weights = rearrange (weights , 'b n -> b n 1' )
19962001
19972002 # Compute weighted centroids
1998- pred_centroid = (pred_coords * weights . unsqueeze ( - 1 )) .sum (dim = 1 ) / weights .sum (dim = 1 , keepdim = True )
1999- true_centroid = (true_coords * weights . unsqueeze ( - 1 )) .sum (dim = 1 ) / weights .sum (dim = 1 , keepdim = True )
2003+ pred_centroid = (pred_coords * weights ) .sum (dim = 1 , keepdim = True ) / weights .sum (dim = 1 , keepdim = True )
2004+ true_centroid = (true_coords * weights ) .sum (dim = 1 , keepdim = True ) / weights .sum (dim = 1 , keepdim = True )
20002005
20012006 # Center the coordinates
2002- pred_coords_centered = pred_coords - pred_centroid . unsqueeze ( 1 )
2003- true_coords_centered = true_coords - true_centroid . unsqueeze ( 1 )
2007+ pred_coords_centered = pred_coords - pred_centroid
2008+ true_coords_centered = true_coords - true_centroid
20042009
20052010 # Compute the weighted covariance matrix
2006- cov_matrix = torch .einsum ('bni,bnj->bij' , true_coords_centered * weights .unsqueeze (- 1 ), pred_coords_centered )
2011+ weighted_true_coords_center = true_coords_centered * weights
2012+ cov_matrix = einsum (weighted_true_coords_center , pred_coords_centered , 'b n i, b n j -> b i j' )
20072013
20082014 # Compute the SVD of the covariance matrix
20092015 U , _ , V = torch .svd (cov_matrix )
20102016
20112017 # Compute the rotation matrix
2012- rot_matrix = torch . einsum ('bij,bjk->bik' , U , V )
2018+ rot_matrix = einsum (U , V , 'b i j, b j k -> b i k' )
20132019
20142020 # Ensure proper rotation matrix with determinant 1
20152021 det = torch .det (rot_matrix )
20162022 det_mask = det < 0
20172023 V_fixed = V .clone ()
20182024 V_fixed [det_mask , :, - 1 ] *= - 1
2019- rot_matrix [det_mask ] = torch .einsum ('bij,bjk->bik' , U [det_mask ], V_fixed [det_mask ])
2025+
2026+ rot_matrix [det_mask ] = einsum (U [det_mask ], V_fixed [det_mask ], 'b i j, b j k -> b i k' )
20202027
20212028 # Apply the rotation and translation
2022- aligned_coords = torch .einsum ('bni,bij->bnj' , pred_coords_centered , rot_matrix ) + true_centroid .unsqueeze (1 )
2029+ aligned_coords = einsum (pred_coords_centered , rot_matrix , 'b n i, b i j -> b n j' ) + true_centroid
2030+ aligned_coords .detach_ ()
20232031
2024- return aligned_coords . detach ()
2032+ return aligned_coords
20252033
20262034class ExpressCoordinatesInFrame (Module ):
20272035 """ Algorithm 29 """
20282036
2029- def __init__ (self , eps = 1e-8 ):
2037+ def __init__ (
2038+ self ,
2039+ eps = 1e-8
2040+ ):
20302041 super ().__init__ ()
20312042 self .eps = eps
20322043
@@ -2037,8 +2048,8 @@ def forward(
20372048 frame : Float ['b m 3 3' ] | Float ['b 3 3' ] | Float ['3 3' ]
20382049 ) -> Float ['b m 3' ]:
20392050 """
2040- coords: coordinates to be expressed in the given frame (b, 3)
2041- frame: frame defined by three points (b, 3, 3)
2051+ coords: coordinates to be expressed in the given frame
2052+ frame: frame defined by three points
20422053 """
20432054
20442055 if frame .ndim == 2 :
@@ -2067,8 +2078,12 @@ def forward(
20672078
20682079class ComputeAlignmentError (Module ):
20692080 """ Algorithm 30 """
2081+
20702082 @typecheck
2071- def __init__ (self , eps : float = 1e-8 ):
2083+ def __init__ (
2084+ self ,
2085+ eps : float = 1e-8
2086+ ):
20722087 super ().__init__ ()
20732088 self .eps = eps
20742089 self .express_coordinates_in_frame = ExpressCoordinatesInFrame ()
@@ -2082,10 +2097,10 @@ def forward(
20822097 true_frames : Float ['b n 3 3' ]
20832098 ) -> Float ['b n' ]:
20842099 """
2085- pred_coords: predicted coordinates (b, n, 3)
2086- true_coords: true coordinates (b, n, 3)
2087- pred_frames: predicted frames (b, n, 3, 3)
2088- true_frames: true frames (b, n, 3, 3)
2100+ pred_coords: predicted coordinates
2101+ true_coords: true coordinates
2102+ pred_frames: predicted frames
2103+ true_frames: true frames
20892104 """
20902105 # Express predicted coordinates in predicted frames
20912106 pred_coords_transformed = self .express_coordinates_in_frame (pred_coords , pred_frames )
@@ -2102,6 +2117,7 @@ def forward(
21022117
21032118class CentreRandomAugmentation (Module ):
21042119 """ Algorithm 19 """
2120+
21052121 @typecheck
21062122 def __init__ (self , trans_scale : float = 1.0 ):
21072123 super ().__init__ ()
@@ -2110,7 +2126,7 @@ def __init__(self, trans_scale: float = 1.0):
21102126 @typecheck
21112127 def forward (self , coords : Float ['b n 3' ]) -> Float ['b n 3' ]:
21122128 """
2113- coords: coordinates to be augmented (b, n, 3)
2129+ coords: coordinates to be augmented
21142130 """
21152131 # Center the coordinates
21162132 centered_coords = coords - coords .mean (dim = 1 , keepdim = True )
0 commit comments