@@ -2293,22 +2293,26 @@ def forward(
22932293 elif frame .ndim == 3 :
22942294 frame = rearrange (frame , 'b fr fc -> b 1 fr fc' )
22952295
2296- # Extract frame points
2297- a , b , c = frame .unbind (dim = - 1 )
2298-
2299- # Compute unit vectors of the frame
2300- e1 = F .normalize (a - b , dim = - 1 , eps = self .eps )
2301- e2 = F .normalize (c - b , dim = - 1 , eps = self .eps )
2302- e3 = torch .cross (e1 , e2 , dim = - 1 )
2303-
2304- # Express coordinates in the frame basis
2305- v = coords - b
2306-
2307- transformed_coords = torch .stack ([
2308- einsum (v , e1 , '... i, ... i -> ...' ),
2309- einsum (v , e2 , '... i, ... i -> ...' ),
2310- einsum (v , e3 , '... i, ... i -> ...' )
2311- ], dim = - 1 )
2296+ # Extract frame atoms
2297+ a , b , c = frame .unbind (dim = - 1 )
2298+ w1 = F .normalize (a - b , dim = - 1 , eps = self .eps )
2299+ w2 = F .normalize (c - b , dim = - 1 , eps = self .eps )
2300+
2301+ # Build orthonormal basis
2302+ e1 = F .normalize (w1 + w2 , dim = - 1 , eps = self .eps )
2303+ e2 = F .normalize (w2 - w1 , dim = - 1 , eps = self .eps )
2304+ e3 = torch .cross (e1 , e2 , dim = - 1 )
2305+
2306+ # Project onto frame basis
2307+ d = coords - b
2308+ transformed_coords = torch .stack (
2309+ [
2310+ einsum (d , e1 , '... i, ... i -> ...' ),
2311+ einsum (d , e2 , '... i, ... i -> ...' ),
2312+ einsum (d , e3 , '... i, ... i -> ...' ),
2313+ ],
2314+ dim = - 1 ,
2315+ )
23122316
23132317 return transformed_coords
23142318
0 commit comments