116116t - templates
117117s - msa
118118r - registers
119+ ts - diffusion timesteps
119120"""
120121
121122"""
@@ -2670,8 +2671,10 @@ def sample(
26702671 clamp = False ,
26712672 use_tqdm_pbar = True ,
26722673 tqdm_pbar_title = 'sampling time step' ,
2674+ return_all_timesteps = False ,
26732675 ** network_condition_kwargs
2674- ):
2676+ ) -> Float ['b m 3' ] | Float ['ts b m 3' ]:
2677+
26752678 step_scale , num_sample_steps = self .step_scale , default (num_sample_steps , self .num_sample_steps )
26762679
26772680 shape = (* atom_mask .shape , 3 )
@@ -2702,6 +2705,8 @@ def sample(
27022705
27032706 maybe_augment_fn = self .centre_random_augmenter if self .augment_during_sampling else identity
27042707
2708+ all_atom_pos = [atom_pos ]
2709+
27052710 for sigma , sigma_next , gamma in maybe_tqdm_wrapper (sigmas_and_gammas , desc = tqdm_pbar_title ):
27062711 sigma , sigma_next , gamma = tuple (t .item () for t in (sigma , sigma_next , gamma ))
27072712
@@ -2726,6 +2731,14 @@ def sample(
27262731
27272732 atom_pos = atom_pos_next
27282733
2734+ all_atom_pos .append (atom_pos )
2735+
2736+ # if returning atom positions across all timesteps for visualization
2737+ # then stack the `all_atom_pos`
2738+
2739+ if return_all_timesteps :
2740+ atom_pos = torch .stack (all_atom_pos )
2741+
27292742 if clamp :
27302743 atom_pos = atom_pos .clamp (- 1. , 1. )
27312744
@@ -5437,6 +5450,7 @@ def forward(
54375450 resolution : Float [' b' ] | None = None ,
54385451 return_loss_breakdown = False ,
54395452 return_loss : bool = None ,
5453+ return_all_diffused_atom_pos : bool = False ,
54405454 return_confidence_head_logits : bool = False ,
54415455 return_distogram_head_logits : bool = False ,
54425456 num_rollout_steps : int | None = None ,
@@ -5447,7 +5461,8 @@ def forward(
54475461 hard_validate : bool = False
54485462 ) -> (
54495463 Float ['b m 3' ] |
5450- Tuple [Float ['b m 3' ], ConfidenceHeadLogits | Alphafold3Logits ] |
5464+ Float ['ts b m 3' ] |
5465+ Tuple [Float ['b m 3' ] | Float ['ts b m 3' ], ConfidenceHeadLogits | Alphafold3Logits ] |
54515466 Float ['' ] |
54525467 Tuple [Float ['' ], LossBreakdown ]
54535468 ):
@@ -5717,11 +5732,12 @@ def forward(
57175732 single_inputs_repr = single_inputs ,
57185733 pairwise_trunk = pairwise ,
57195734 pairwise_rel_pos_feats = relative_position_encoding ,
5720- molecule_atom_lens = molecule_atom_lens
5735+ molecule_atom_lens = molecule_atom_lens ,
5736+ return_all_timesteps = return_all_diffused_atom_pos
57215737 )
57225738
57235739 if exists (atom_mask ):
5724- sampled_atom_pos = einx .where ('b m, b m c, -> b m c' , atom_mask , sampled_atom_pos , 0. )
5740+ sampled_atom_pos = einx .where ('b m, ... b m c, -> ... b m c' , atom_mask , sampled_atom_pos , 0. )
57255741
57265742 if return_confidence_head_logits :
57275743 confidence_head_atom_pos_input = sampled_atom_pos .clone ()
0 commit comments