Skip to content

Commit 499fb2b

Browse files
committed
one tiny refactor to get ready for atom resolution confidence heads
1 parent d9e4fd4 commit 499fb2b

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2925,14 +2925,24 @@ def forward(
29252925
single_inputs_repr: Float['b n dsi'],
29262926
single_repr: Float['b n ds'],
29272927
pairwise_repr: Float['b n n dp'],
2928-
pred_atom_pos: Float['b n 3'],
2928+
pred_atom_pos: Float['b n 3'] | Float['b m 3'],
2929+
molecule_atom_indices: Int['b n'] | None = None,
29292930
mask: Bool['b n'] | None = None,
29302931
return_pae_logits = True
29312932

29322933
) -> ConfidenceHeadLogits:
29332934

29342935
pairwise_repr = pairwise_repr + self.single_inputs_to_pairwise(single_inputs_repr)
29352936

2937+
# pluck out the representative atoms for non-atomic resolution confidence head
2938+
2939+
is_atom_seq = pred_atom_pos.shape[-2] > single_inputs_repr.shape[-2]
2940+
2941+
assert not is_atom_seq or exists(molecule_atom_indices)
2942+
2943+
if is_atom_seq:
2944+
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', pred_atom_pos, molecule_atom_indices)
2945+
29362946
# interatomic distances - embed and add to pairwise
29372947

29382948
interatom_dist = torch.cdist(pred_atom_pos, pred_atom_pos, p = 2)
@@ -3640,8 +3650,7 @@ def forward(
36403650
sampled_atom_pos = einx.where('b m, b m c, -> b m c', atom_mask, sampled_atom_pos, 0.)
36413651

36423652
if return_confidence_head_logits:
3643-
assert exists(molecule_atom_indices)
3644-
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', sampled_atom_pos, molecule_atom_indices)
3653+
confidence_head_atom_pos_input = sampled_atom_pos.clone()
36453654

36463655
if exists(missing_atom_mask) and return_present_sampled_atoms:
36473656
sampled_atom_pos = sampled_atom_pos[~missing_atom_mask]
@@ -3653,7 +3662,8 @@ def forward(
36533662
single_repr = single.detach(),
36543663
single_inputs_repr = single_inputs.detach(),
36553664
pairwise_repr = pairwise.detach(),
3656-
pred_atom_pos = pred_atom_pos.detach(),
3665+
pred_atom_pos = confidence_head_atom_pos_input.detach(),
3666+
molecule_atom_indices = molecule_atom_indices,
36573667
mask = mask,
36583668
return_pae_logits = True
36593669
)
@@ -3822,13 +3832,12 @@ def forward(
38223832
tqdm_pbar_title = 'training rollout'
38233833
)
38243834

3825-
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', denoised_atom_pos, molecule_atom_indices)
3826-
38273835
ch_logits = self.confidence_head(
38283836
single_repr = single.detach(),
38293837
single_inputs_repr = single_inputs.detach(),
38303838
pairwise_repr = pairwise.detach(),
3831-
pred_atom_pos = pred_atom_pos.detach(),
3839+
pred_atom_pos = denoised_atom_pos.detach(),
3840+
molecule_atom_indices = molecule_atom_indices,
38323841
mask = mask,
38333842
return_pae_logits = return_pae_logits
38343843
)

0 commit comments

Comments
 (0)