@@ -2120,53 +2120,63 @@ class CentreRandomAugmentation(Module):
21202120 def __init__ (self , trans_scale : float = 1.0 ):
21212121 super ().__init__ ()
21222122 self .trans_scale = trans_scale
2123+ self .register_buffer ('dummy' , torch .tensor (0 ), persistent = False )
2124+
2125+ @property
2126+ def device (self ):
2127+ return self .dummy .device
21232128
21242129 @typecheck
21252130 def forward (self , coords : Float ['b n 3' ]) -> Float ['b n 3' ]:
21262131 """
21272132 coords: coordinates to be augmented
21282133 """
2134+ batch_size = coords .shape [0 ]
2135+
21292136 # Center the coordinates
21302137 centered_coords = coords - coords .mean (dim = 1 , keepdim = True )
21312138
21322139 # Generate random rotation matrix
2133- rotation_matrix = self ._random_rotation_matrix (coords . device )
2140+ rotation_matrix = self ._random_rotation_matrix (batch_size )
21342141
21352142 # Generate random translation vector
2136- translation_vector = self ._random_translation_vector (coords .device )
2143+ translation_vector = self ._random_translation_vector (batch_size )
2144+ translation_vector = rearrange (translation_vector , 'b c -> b 1 c' )
21372145
21382146 # Apply rotation and translation
2139- augmented_coords = torch . einsum ('bni,ij->bnj' , centered_coords , rotation_matrix ) + translation_vector
2147+ augmented_coords = einsum (centered_coords , rotation_matrix , 'b n i, b i j -> b n j' ) + translation_vector
21402148
21412149 return augmented_coords
21422150
21432151 @typecheck
2144- def _random_rotation_matrix (self , device : torch . device ) -> Float ['3 3' ]:
2152+ def _random_rotation_matrix (self , batch_size : int ) -> Float ['b 3 3' ]:
21452153 # Generate random rotation angles
2146- angles = torch .rand (3 , device = device ) * 2 * torch .pi
2154+ angles = torch .rand (( batch_size , 3 ) , device = self . device ) * 2 * torch .pi
21472155
21482156 # Compute sine and cosine of angles
21492157 sin_angles = torch .sin (angles )
21502158 cos_angles = torch .cos (angles )
21512159
21522160 # Construct rotation matrix
2153- rotation_matrix = torch .eye (3 , device = device )
2154- rotation_matrix [0 , 0 ] = cos_angles [0 ] * cos_angles [1 ]
2155- rotation_matrix [0 , 1 ] = cos_angles [0 ] * sin_angles [1 ] * sin_angles [2 ] - sin_angles [0 ] * cos_angles [2 ]
2156- rotation_matrix [0 , 2 ] = cos_angles [0 ] * sin_angles [1 ] * cos_angles [2 ] + sin_angles [0 ] * sin_angles [2 ]
2157- rotation_matrix [1 , 0 ] = sin_angles [0 ] * cos_angles [1 ]
2158- rotation_matrix [1 , 1 ] = sin_angles [0 ] * sin_angles [1 ] * sin_angles [2 ] + cos_angles [0 ] * cos_angles [2 ]
2159- rotation_matrix [1 , 2 ] = sin_angles [0 ] * sin_angles [1 ] * cos_angles [2 ] - cos_angles [0 ] * sin_angles [2 ]
2160- rotation_matrix [2 , 0 ] = - sin_angles [1 ]
2161- rotation_matrix [2 , 1 ] = cos_angles [1 ] * sin_angles [2 ]
2162- rotation_matrix [2 , 2 ] = cos_angles [1 ] * cos_angles [2 ]
2161+ eye = torch .eye (3 , device = self .device )
2162+ rotation_matrix = repeat (eye , 'i j -> b i j' , b = batch_size ).clone ()
2163+
2164+ rotation_matrix [:, 0 , 0 ] = cos_angles [:, 0 ] * cos_angles [:, 1 ]
2165+ rotation_matrix [:, 0 , 1 ] = cos_angles [:, 0 ] * sin_angles [:, 1 ] * sin_angles [:, 2 ] - sin_angles [:, 0 ] * cos_angles [:, 2 ]
2166+ rotation_matrix [:, 0 , 2 ] = cos_angles [:, 0 ] * sin_angles [:, 1 ] * cos_angles [:, 2 ] + sin_angles [:, 0 ] * sin_angles [:, 2 ]
2167+ rotation_matrix [:, 1 , 0 ] = sin_angles [:, 0 ] * cos_angles [:, 1 ]
2168+ rotation_matrix [:, 1 , 1 ] = sin_angles [:, 0 ] * sin_angles [:, 1 ] * sin_angles [:, 2 ] + cos_angles [:, 0 ] * cos_angles [:, 2 ]
2169+ rotation_matrix [:, 1 , 2 ] = sin_angles [:, 0 ] * sin_angles [:, 1 ] * cos_angles [:, 2 ] - cos_angles [:, 0 ] * sin_angles [:, 2 ]
2170+ rotation_matrix [:, 2 , 0 ] = - sin_angles [:, 1 ]
2171+ rotation_matrix [:, 2 , 1 ] = cos_angles [:, 1 ] * sin_angles [:, 2 ]
2172+ rotation_matrix [:, 2 , 2 ] = cos_angles [:, 1 ] * cos_angles [:, 2 ]
21632173
21642174 return rotation_matrix
21652175
21662176 @typecheck
2167- def _random_translation_vector (self , device : torch . device ) -> Float ['3' ]:
2177+ def _random_translation_vector (self , batch_size : int ) -> Float ['b 3' ]:
21682178 # Generate random translation vector
2169- translation_vector = torch .randn (3 , device = device ) * self .trans_scale
2179+ translation_vector = torch .randn (( batch_size , 3 ) , device = self . device ) * self .trans_scale
21702180 return translation_vector
21712181
21722182# input embedder
0 commit comments