7373
7474ADDITIONAL_RESIDUE_FEATS = 10
7575
76+ # threshold for checking that point cross-correlation
77+ # is full-rank in `WeightedRigidAlign`
78+ AMBIGUOUS_ROT_SINGULAR_THR = 1e-15
79+
7680LinearNoBias = partial (Linear , bias = False )
7781
7882# helper functions
@@ -2210,6 +2214,7 @@ def forward(
22102214 weights : Float ['b n' ], # weights for each atom
22112215 mask : Bool ['b n' ] | None = None # mask for variable lengths
22122216 ) -> Float ['b n 3' ]:
2217+ batch_size , num_points , dim = pred_coords .shape
22132218
22142219 if exists (mask ):
22152220 # zero out all predicted and true coordinates where not an atom
@@ -2228,26 +2233,36 @@ def forward(
22282233 pred_coords_centered = pred_coords - pred_centroid
22292234 true_coords_centered = true_coords - true_centroid
22302235
2236+ if num_points < (dim + 1 ):
2237+ print (
2238+ "Warning: The size of one of the point clouds is <= dim+1. "
2239+ + "`WeightedRigidAlign` cannot return a unique rotation."
2240+ )
2241+
22312242 # Compute the weighted covariance matrix
2232- weighted_true_coords_center = true_coords_centered * weights
2233- cov_matrix = einsum (weighted_true_coords_center , pred_coords_centered , 'b n i, b n j -> b i j' )
2243+ cov_matrix = einsum (weights * true_coords_centered , pred_coords_centered , 'b n i, b n j -> b i j' )
22342244
22352245 # Compute the SVD of the covariance matrix
2236- U , _ , V = torch .svd (cov_matrix )
2246+ U , S , V = torch .svd (cov_matrix )
2247+
2248+ # Catch ambiguous rotation by checking the magnitude of singular values
2249+ if (S .abs () <= AMBIGUOUS_ROT_SINGULAR_THR ).any () and not (num_points < (dim + 1 )):
2250+ print (
2251+ "Warning: Excessively low rank of "
2252+ + "cross-correlation between aligned point clouds. "
2253+ + "`WeightedRigidAlign` cannot return a unique rotation."
2254+ )
22372255
22382256 # Compute the rotation matrix
22392257 rot_matrix = einsum (U , V , 'b i j, b k j -> b i k' )
22402258
22412259 # Ensure proper rotation matrix with determinant 1
2242- det = torch .det (rot_matrix )
2243- det_mask = det < 0
2244- V_fixed = V .clone ()
2245- V_fixed [det_mask , :, - 1 ] *= - 1
2246-
2247- rot_matrix [det_mask ] = einsum (U [det_mask ], V_fixed [det_mask ], 'b i j, b k j -> b i k' )
2260+ F = torch .eye (dim , dtype = cov_matrix .dtype , device = cov_matrix .device )[None ].repeat (batch_size , 1 , 1 )
2261+ F [:, - 1 , - 1 ] = torch .det (rot_matrix )
2262+ rot_matrix = einsum (U , F , V , "b i j, b j k, b l k -> b i l" )
22482263
22492264 # Apply the rotation and translation
2250- aligned_coords = einsum (pred_coords_centered , rot_matrix , 'b n i, b i j -> b n j' ) + true_centroid
2265+ aligned_coords = einsum (pred_coords_centered , rot_matrix , 'b n i, b j i -> b n j' ) + true_centroid
22512266 aligned_coords .detach_ ()
22522267
22532268 return aligned_coords
@@ -2367,7 +2382,7 @@ def forward(self, coords: Float['b n 3']) -> Float['b n 3']:
23672382 translation_vector = rearrange (translation_vector , 'b c -> b 1 c' )
23682383
23692384 # Apply rotation and translation
2370- augmented_coords = einsum (centered_coords , rotation_matrix , 'b n i, b i j -> b n j' ) + translation_vector
2385+ augmented_coords = einsum (centered_coords , rotation_matrix , 'b n i, b j i -> b n j' ) + translation_vector
23712386
23722387 return augmented_coords
23732388
0 commit comments