Skip to content

Commit 57f6da2

Browse files
committed
centre random augmentation needs to be done for all batch samples separately
1 parent b93e278 commit 57f6da2

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,53 +2120,63 @@ class CentreRandomAugmentation(Module):
21202120
def __init__(self, trans_scale: float = 1.0):
21212121
super().__init__()
21222122
self.trans_scale = trans_scale
2123+
self.register_buffer('dummy', torch.tensor(0), persistent = False)
2124+
2125+
@property
2126+
def device(self):
2127+
return self.dummy.device
21232128

21242129
@typecheck
21252130
def forward(self, coords: Float['b n 3']) -> Float['b n 3']:
21262131
"""
21272132
coords: coordinates to be augmented
21282133
"""
2134+
batch_size = coords.shape[0]
2135+
21292136
# Center the coordinates
21302137
centered_coords = coords - coords.mean(dim=1, keepdim=True)
21312138

21322139
# Generate random rotation matrix
2133-
rotation_matrix = self._random_rotation_matrix(coords.device)
2140+
rotation_matrix = self._random_rotation_matrix(batch_size)
21342141

21352142
# Generate random translation vector
2136-
translation_vector = self._random_translation_vector(coords.device)
2143+
translation_vector = self._random_translation_vector(batch_size)
2144+
translation_vector = rearrange(translation_vector, 'b c -> b 1 c')
21372145

21382146
# Apply rotation and translation
2139-
augmented_coords = torch.einsum('bni,ij->bnj', centered_coords, rotation_matrix) + translation_vector
2147+
augmented_coords = einsum(centered_coords, rotation_matrix, 'b n i, b i j -> b n j') + translation_vector
21402148

21412149
return augmented_coords
21422150

21432151
@typecheck
2144-
def _random_rotation_matrix(self, device: torch.device) -> Float['3 3']:
2152+
def _random_rotation_matrix(self, batch_size: int) -> Float['b 3 3']:
21452153
# Generate random rotation angles
2146-
angles = torch.rand(3, device=device) * 2 * torch.pi
2154+
angles = torch.rand((batch_size, 3), device = self.device) * 2 * torch.pi
21472155

21482156
# Compute sine and cosine of angles
21492157
sin_angles = torch.sin(angles)
21502158
cos_angles = torch.cos(angles)
21512159

21522160
# Construct rotation matrix
2153-
rotation_matrix = torch.eye(3, device=device)
2154-
rotation_matrix[0, 0] = cos_angles[0] * cos_angles[1]
2155-
rotation_matrix[0, 1] = cos_angles[0] * sin_angles[1] * sin_angles[2] - sin_angles[0] * cos_angles[2]
2156-
rotation_matrix[0, 2] = cos_angles[0] * sin_angles[1] * cos_angles[2] + sin_angles[0] * sin_angles[2]
2157-
rotation_matrix[1, 0] = sin_angles[0] * cos_angles[1]
2158-
rotation_matrix[1, 1] = sin_angles[0] * sin_angles[1] * sin_angles[2] + cos_angles[0] * cos_angles[2]
2159-
rotation_matrix[1, 2] = sin_angles[0] * sin_angles[1] * cos_angles[2] - cos_angles[0] * sin_angles[2]
2160-
rotation_matrix[2, 0] = -sin_angles[1]
2161-
rotation_matrix[2, 1] = cos_angles[1] * sin_angles[2]
2162-
rotation_matrix[2, 2] = cos_angles[1] * cos_angles[2]
2161+
eye = torch.eye(3, device = self.device)
2162+
rotation_matrix = repeat(eye, 'i j -> b i j', b = batch_size).clone()
2163+
2164+
rotation_matrix[:, 0, 0] = cos_angles[:, 0] * cos_angles[:, 1]
2165+
rotation_matrix[:, 0, 1] = cos_angles[:, 0] * sin_angles[:, 1] * sin_angles[:, 2] - sin_angles[:, 0] * cos_angles[:, 2]
2166+
rotation_matrix[:, 0, 2] = cos_angles[:, 0] * sin_angles[:, 1] * cos_angles[:, 2] + sin_angles[:, 0] * sin_angles[:, 2]
2167+
rotation_matrix[:, 1, 0] = sin_angles[:, 0] * cos_angles[:, 1]
2168+
rotation_matrix[:, 1, 1] = sin_angles[:, 0] * sin_angles[:, 1] * sin_angles[:, 2] + cos_angles[:, 0] * cos_angles[:, 2]
2169+
rotation_matrix[:, 1, 2] = sin_angles[:, 0] * sin_angles[:, 1] * cos_angles[:, 2] - cos_angles[:, 0] * sin_angles[:, 2]
2170+
rotation_matrix[:, 2, 0] = -sin_angles[:, 1]
2171+
rotation_matrix[:, 2, 1] = cos_angles[:, 1] * sin_angles[:, 2]
2172+
rotation_matrix[:, 2, 2] = cos_angles[:, 1] * cos_angles[:, 2]
21632173

21642174
return rotation_matrix
21652175

21662176
@typecheck
2167-
def _random_translation_vector(self, device: torch.device) -> Float['3']:
2177+
def _random_translation_vector(self, batch_size: int) -> Float['b 3']:
21682178
# Generate random translation vector
2169-
translation_vector = torch.randn(3, device=device) * self.trans_scale
2179+
translation_vector = torch.randn((batch_size, 3), device = self.device) * self.trans_scale
21702180
return translation_vector
21712181

21722182
# input embedder

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.24"
3+
version = "0.0.25"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)