Skip to content

Commit dde0405

Browse files
committed
able to return all atom positions across diffusion timesteps for visualization purposes
1 parent 49bb8a5 commit dde0405

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
t - templates
117117
s - msa
118118
r - registers
119+
ts - diffusion timesteps
119120
"""
120121

121122
"""
@@ -2670,8 +2671,10 @@ def sample(
26702671
clamp = False,
26712672
use_tqdm_pbar = True,
26722673
tqdm_pbar_title = 'sampling time step',
2674+
return_all_timesteps = False,
26732675
**network_condition_kwargs
2674-
):
2676+
) -> Float['b m 3'] | Float['ts b m 3']:
2677+
26752678
step_scale, num_sample_steps = self.step_scale, default(num_sample_steps, self.num_sample_steps)
26762679

26772680
shape = (*atom_mask.shape, 3)
@@ -2702,6 +2705,8 @@ def sample(
27022705

27032706
maybe_augment_fn = self.centre_random_augmenter if self.augment_during_sampling else identity
27042707

2708+
all_atom_pos = [atom_pos]
2709+
27052710
for sigma, sigma_next, gamma in maybe_tqdm_wrapper(sigmas_and_gammas, desc = tqdm_pbar_title):
27062711
sigma, sigma_next, gamma = tuple(t.item() for t in (sigma, sigma_next, gamma))
27072712

@@ -2726,6 +2731,14 @@ def sample(
27262731

27272732
atom_pos = atom_pos_next
27282733

2734+
all_atom_pos.append(atom_pos)
2735+
2736+
# if returning atom positions across all timesteps for visualization
2737+
# then stack the `all_atom_pos`
2738+
2739+
if return_all_timesteps:
2740+
atom_pos = torch.stack(all_atom_pos)
2741+
27292742
if clamp:
27302743
atom_pos = atom_pos.clamp(-1., 1.)
27312744

@@ -5437,6 +5450,7 @@ def forward(
54375450
resolution: Float[' b'] | None = None,
54385451
return_loss_breakdown = False,
54395452
return_loss: bool = None,
5453+
return_all_diffused_atom_pos: bool = False,
54405454
return_confidence_head_logits: bool = False,
54415455
return_distogram_head_logits: bool = False,
54425456
num_rollout_steps: int | None = None,
@@ -5447,7 +5461,8 @@ def forward(
54475461
hard_validate: bool = False
54485462
) -> (
54495463
Float['b m 3'] |
5450-
Tuple[Float['b m 3'], ConfidenceHeadLogits | Alphafold3Logits] |
5464+
Float['ts b m 3'] |
5465+
Tuple[Float['b m 3'] | Float['ts b m 3'], ConfidenceHeadLogits | Alphafold3Logits] |
54515466
Float[''] |
54525467
Tuple[Float[''], LossBreakdown]
54535468
):
@@ -5717,11 +5732,12 @@ def forward(
57175732
single_inputs_repr = single_inputs,
57185733
pairwise_trunk = pairwise,
57195734
pairwise_rel_pos_feats = relative_position_encoding,
5720-
molecule_atom_lens = molecule_atom_lens
5735+
molecule_atom_lens = molecule_atom_lens,
5736+
return_all_timesteps = return_all_diffused_atom_pos
57215737
)
57225738

57235739
if exists(atom_mask):
5724-
sampled_atom_pos = einx.where('b m, b m c, -> b m c', atom_mask, sampled_atom_pos, 0.)
5740+
sampled_atom_pos = einx.where('b m, ... b m c, -> ... b m c', atom_mask, sampled_atom_pos, 0.)
57255741

57265742
if return_confidence_head_logits:
57275743
confidence_head_atom_pos_input = sampled_atom_pos.clone()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.4.3"
3+
version = "0.4.4"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)