Skip to content

Commit d306353

Browse files
authored
Update alphafold3.py (#199)
1 parent 99799e9 commit d306353

File tree

1 file changed

+61
-34
lines changed

1 file changed

+61
-34
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,8 @@ def forward(
935935
mask: Bool['b n'] | None = None,
936936
msa_mask: Bool['b s'] | None = None
937937
) -> Float['b n n dp']:
938+
939+
dtype = msa.dtype
938940

939941
msa = self.norm(msa)
940942

@@ -945,12 +947,12 @@ def forward(
945947
# maybe masked mean for outer product
946948

947949
if exists(msa_mask):
948-
a = einx.multiply('b s i d, b s -> b s i d', a, msa_mask.float())
949-
b = einx.multiply('b s j e, b s -> b s j e', b, msa_mask.float())
950+
a = einx.multiply('b s i d, b s -> b s i d', a, msa_mask.type(dtype))
951+
b = einx.multiply('b s j e, b s -> b s j e', b, msa_mask.type(dtype))
950952

951953
outer_product = einsum(a, b, 'b s i d, b s j e -> b i j d e')
952954

953-
num_msa = reduce(msa_mask.float(), '... s -> ...', 'sum')
955+
num_msa = reduce(msa_mask.type(dtype), '... s -> ...', 'sum')
954956

955957
outer_product_mean = einx.divide('b i j d e, b', outer_product, num_msa.clamp(min = self.eps))
956958
else:
@@ -966,7 +968,9 @@ def forward(
966968

967969
if exists(mask):
968970
mask = to_pairwise_mask(mask)
969-
outer_product_mean = einx.multiply('b i j d, b i j', outer_product_mean, mask.float())
971+
outer_product_mean = einx.multiply(
972+
'b i j d, b i j', outer_product_mean, mask.type(dtype)
973+
)
970974

971975
pairwise_repr = self.to_pairwise_repr(outer_product_mean)
972976
return pairwise_repr
@@ -1520,6 +1524,7 @@ def forward(
15201524
additional_molecule_feats: Int[f'b n {ADDITIONAL_MOLECULE_FEATS}']
15211525
) -> Float['b n n dp']:
15221526

1527+
dtype = self.out_embedder.weight.dtype
15231528
device = additional_molecule_feats.device
15241529

15251530
res_idx, token_idx, asym_id, entity_id, sym_id = additional_molecule_feats.unbind(dim = -1)
@@ -1554,7 +1559,7 @@ def onehot(x, bins):
15541559
dist_from_bins = einx.subtract('... i, j -> ... i j', x, bins)
15551560
indices = dist_from_bins.abs().min(dim = -1, keepdim = True).indices
15561561
one_hots = F.one_hot(indices.long(), num_classes = len(bins))
1557-
return one_hots.float()
1562+
return one_hots.type(dtype)
15581563

15591564
r_arange = torch.arange(2*self.r_max + 2, device = device)
15601565
s_arange = torch.arange(2*self.s_max + 2, device = device)
@@ -1675,6 +1680,7 @@ def forward(
16751680
mask: Bool['b n'] | None = None,
16761681
) -> Float['b n n dp']:
16771682

1683+
dtype = templates.dtype
16781684
num_templates = templates.shape[1]
16791685

16801686
pairwise_repr = self.pairwise_to_embed_input(pairwise_repr)
@@ -1714,7 +1720,7 @@ def forward(
17141720
)
17151721

17161722
num = reduce(templates, 'b t i j d -> b i j d', 'sum')
1717-
den = reduce(template_mask.float(), 'b t -> b', 'sum')
1723+
den = reduce(template_mask.type(dtype), 'b t -> b', 'sum')
17181724

17191725
avg_template_repr = einx.divide('b i j d, b -> b i j d', num, den.clamp(min = self.eps))
17201726

@@ -2612,6 +2618,10 @@ def __init__(
26122618
def device(self):
26132619
return next(self.net.parameters()).device
26142620

2621+
@property
2622+
def dtype(self):
2623+
return next(self.net.parameters()).dtype
2624+
26152625
# derived preconditioning params - Table 1
26162626

26172627
def c_skip(self, sigma):
@@ -2637,10 +2647,14 @@ def preconditioned_network_forward(
26372647
network_condition_kwargs: dict,
26382648
clamp = False,
26392649
):
2640-
batch, device = noised_atom_pos.shape[0], noised_atom_pos.device
2650+
batch, dtype, device = (
2651+
noised_atom_pos.shape[0],
2652+
noised_atom_pos.dtype,
2653+
noised_atom_pos.device,
2654+
)
26412655

26422656
if isinstance(sigma, float):
2643-
sigma = torch.full((batch,), sigma, device = device)
2657+
sigma = torch.full((batch,), sigma, dtype=dtype, device=device)
26442658

26452659
padded_sigma = rearrange(sigma, 'b -> b 1 1')
26462660

@@ -2668,7 +2682,7 @@ def sample_schedule(self, num_sample_steps = None):
26682682
N = num_sample_steps
26692683
inv_rho = 1 / self.rho
26702684

2671-
steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
2685+
steps = torch.arange(num_sample_steps, device=self.device, dtype=self.dtype)
26722686
sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho
26732687

26742688
sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
@@ -2687,6 +2701,8 @@ def sample(
26872701
**network_condition_kwargs
26882702
) -> Float['b m 3'] | Float['ts b m 3']:
26892703

2704+
dtype = self.dtype
2705+
26902706
step_scale, num_sample_steps = self.step_scale, default(num_sample_steps, self.num_sample_steps)
26912707

26922708
shape = (*atom_mask.shape, 3)
@@ -2709,7 +2725,7 @@ def sample(
27092725

27102726
init_sigma = sigmas[0]
27112727

2712-
atom_pos = init_sigma * torch.randn(shape, device = self.device)
2728+
atom_pos = init_sigma * torch.randn(shape, dtype = dtype, device = self.device)
27132729

27142730
# gradually denoise
27152731

@@ -2722,9 +2738,11 @@ def sample(
27222738
for sigma, sigma_next, gamma in maybe_tqdm_wrapper(sigmas_and_gammas, desc = tqdm_pbar_title):
27232739
sigma, sigma_next, gamma = tuple(t.item() for t in (sigma, sigma_next, gamma))
27242740

2725-
atom_pos = maybe_augment_fn(atom_pos)
2741+
atom_pos = maybe_augment_fn(atom_pos.float()).type(dtype)
27262742

2727-
eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
2743+
eps = self.S_noise * torch.randn(
2744+
shape, dtype = dtype, device = self.device
2745+
) # stochastic sampling
27282746

27292747
sigma_hat = sigma + gamma * sigma
27302748
atom_pos_hat = atom_pos + sqrt(sigma_hat ** 2 - sigma ** 2) * eps
@@ -2797,9 +2815,10 @@ def forward(
27972815

27982816
# diffusion loss
27992817

2818+
dtype = atom_pos_ground_truth.dtype
28002819
batch_size = atom_pos_ground_truth.shape[0]
28012820

2802-
sigmas = self.noise_distribution(batch_size)
2821+
sigmas = self.noise_distribution(batch_size).type(dtype)
28032822
padded_sigmas = rearrange(sigmas, 'b -> b 1 1')
28042823

28052824
noise = torch.randn_like(atom_pos_ground_truth)
@@ -2839,11 +2858,11 @@ def forward(
28392858
)
28402859

28412860
atom_pos_aligned_ground_truth = self.weighted_rigid_align(
2842-
pred_coords=denoised_atom_pos,
2843-
true_coords=atom_pos_ground_truth,
2844-
weights=align_weights,
2861+
pred_coords=denoised_atom_pos.float(),
2862+
true_coords=atom_pos_ground_truth.float(),
2863+
weights=align_weights.float(),
28452864
mask=atom_mask,
2846-
)
2865+
).type(dtype)
28472866

28482867
# section 4.2 - multi-chain permutation alignment
28492868

@@ -3375,7 +3394,9 @@ def calculate_optimal_transform(
33753394
selected anchor truth as well as a matrix that records how the atoms should be shifted after applying `r`.
33763395
N.b., Optimal alignment requires 1) a rotation and 2) a shift of the positions.
33773396
"""
3397+
dtype = pred_pos.dtype
33783398
batch_size = pred_pos.shape[0]
3399+
33793400
input_mask = self.calculate_input_mask(
33803401
true_masks=true_masks,
33813402
anchor_gt_idx=anchor_gt_idx,
@@ -3389,13 +3410,13 @@ def calculate_optimal_transform(
33893410
b=batch_size,
33903411
)
33913412
_, r, x = self.weighted_rigid_align(
3392-
pred_coords=anchor_pred_pos,
3393-
true_coords=anchor_true_pos,
3413+
pred_coords=anchor_pred_pos.float(),
3414+
true_coords=anchor_true_pos.float(),
33943415
mask=input_mask,
33953416
return_transforms=True,
33963417
)
33973418

3398-
return r, x
3419+
return r.type(dtype), x.type(dtype)
33993420

34003421
@staticmethod
34013422
@typecheck
@@ -4498,7 +4519,9 @@ def compute_plddt(
44984519
logits = rearrange(logits, "b plddt m -> b m plddt")
44994520
num_bins = logits.shape[-1]
45004521
bin_width = 1.0 / num_bins
4501-
bin_centers = torch.arange(0.5 * bin_width, 1.0, bin_width, device=logits.device)
4522+
bin_centers = torch.arange(
4523+
0.5 * bin_width, 1.0, bin_width, dtype=logits.dtype, device=logits.device
4524+
)
45024525
probs = F.softmax(logits, dim=-1)
45034526

45044527
predicted_lddt = einsum(probs, bin_centers, "b m plddt, plddt -> b m")
@@ -5177,13 +5200,15 @@ def compute_gpde(
51775200
:return: [b] global PDE
51785201
"""
51795202

5203+
dtype = pde_logits.dtype
5204+
51805205
pde = self.compute_confidence_score.compute_pde(pde_logits, tok_repr_atm_mask)
51815206

51825207
dist_logits = rearrange(dist_logits, "b dist i j -> b i j dist")
51835208
dist_probs = F.softmax(dist_logits, dim=-1)
51845209

51855210
# for distances greater than the last breaks
5186-
dist_breaks = F.pad(dist_breaks, (0, 1), value=1e6)
5211+
dist_breaks = F.pad(dist_breaks.float(), (0, 1), value=1e6).type(dtype)
51875212
contact_mask = dist_breaks < self.contact_mask_threshold
51885213

51895214
contact_prob = einx.where(
@@ -5219,6 +5244,7 @@ def compute_lddt(
52195244
:return: lDDT
52205245
"""
52215246

5247+
dtype = pred_coords.dtype
52225248
atom_seq_len, device = pred_coords.shape[1], pred_coords.device
52235249

52245250
# Compute distances between all pairs of atoms
@@ -5229,7 +5255,7 @@ def compute_lddt(
52295255
dist_diff = torch.abs(true_dists - pred_dists)
52305256

52315257
lddt = einx.subtract('thresholds, ... -> ... thresholds', self.lddt_thresholds, dist_diff)
5232-
lddt = (lddt >= 0).float().mean(dim = -1)
5258+
lddt = (lddt >= 0).type(dtype).mean(dim=-1)
52335259

52345260
# Restrict to bespoke inclusion radius
52355261
is_nucleotide = is_dna | is_rna
@@ -6267,6 +6293,8 @@ def forward(
62676293
atom_seq_len = atom_inputs.shape[-2]
62686294
single_structure_input = atom_inputs.shape[0] == 1
62696295

6296+
dtype = atom_inputs.dtype
6297+
62706298
# validate atom and atompair input dimensions
62716299

62726300
assert atom_inputs.shape[-1] == self.dim_atom_inputs, f'expected {self.dim_atom_inputs} for atom_inputs feature dimension, but received {atom_inputs.shape[-1]}'
@@ -6420,7 +6448,7 @@ def forward(
64206448
seq_arange = torch.arange(seq_len, device = self.device)
64216449
token_bonds = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
64226450

6423-
token_bonds_feats = self.token_bond_to_pairwise_feat(token_bonds.float())
6451+
token_bonds_feats = self.token_bond_to_pairwise_feat(token_bonds.type(dtype))
64246452

64256453
pairwise_init = pairwise_init + token_bonds_feats
64266454

@@ -6711,13 +6739,12 @@ def forward(
67116739
fa_atom_mask, aug_atom_mask = atom_mask[:1], atom_mask[1:]
67126740

67136741
fa_atom_pos = self.frame_average(
6714-
fa_atom_pos,
6715-
frame_average_mask = fa_atom_mask
6716-
)
6742+
fa_atom_pos.float(), frame_average_mask = fa_atom_mask
6743+
).type(dtype)
67176744

67186745
# normal random augmentations, 48 times in paper
67196746

6720-
atom_pos = self.augmenter(atom_pos, mask = aug_atom_mask)
6747+
atom_pos = self.augmenter(atom_pos.float(), mask = aug_atom_mask).type(dtype)
67216748

67226749
# concat back the stochastic frame averaged position
67236750

@@ -6793,11 +6820,11 @@ def forward(
67936820

67946821
try:
67956822
atom_pos = self.weighted_rigid_align(
6796-
pred_coords=denoised_atom_pos,
6797-
true_coords=atom_pos,
6798-
weights=align_weights,
6823+
pred_coords=denoised_atom_pos.float(),
6824+
true_coords=atom_pos.float(),
6825+
weights=align_weights.float(),
67996826
mask=atom_mask,
6800-
)
6827+
).type(dtype)
68016828
except Exception as e:
68026829
# NOTE: For many (random) unit test inputs, weighted rigid alignment can be unstable
68036830
logger.warning(f"Skipping weighted rigid alignment due to: {e}")
@@ -7009,7 +7036,7 @@ def forward(
70097036
lddt = einx.subtract(
70107037
"thresholds, ... -> ... thresholds", self.lddt_thresholds, dist_diff
70117038
)
7012-
lddt = (lddt >= 0).float().mean(dim=-1)
7039+
lddt = (lddt >= 0).type(dtype).mean(dim=-1)
70137040

70147041
# calculate masked averaging,
70157042
# after which we assign each value to one of 50 equally sized bins
@@ -7047,7 +7074,7 @@ def forward(
70477074
else torch.full((batch_size,), False, device=self.device)
70487075
)
70497076

7050-
confidence_weight = confidence_mask.float()
7077+
confidence_weight = confidence_mask.type(dtype)
70517078

70527079
@typecheck
70537080
def cross_entropy_with_weight(

0 commit comments

Comments
 (0)