Skip to content

Commit 1e1251f

Browse files
authored
Using the rollout coordinates and stop gradient on representations for confidence head (#26)
* Using the predicted coordinates of the confidence head. * set default rollout steps to 20 * Stop gradient for representation and prediction of structure for the confidence head.
1 parent 5878227 commit 1e1251f

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3079,7 +3079,8 @@ def forward(
30793079
plddt_labels: Int['b n'] | None = None,
30803080
resolved_labels: Int['b n'] | None = None,
30813081
return_loss_breakdown = False,
3082-
return_loss_if_possible: bool = True
3082+
return_loss_if_possible: bool = True,
3083+
num_rollout_steps: int = 20,
30833084
) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
30843085

30853086
atom_seq_len = atom_inputs.shape[-2]
@@ -3361,17 +3362,31 @@ def forward(
33613362
return_pae_logits = exists(pae_labels)
33623363

33633364
if calc_diffusion_loss and should_call_confidence_head:
3365+
3366+
# rollout
3367+
pred_atom_pos = self.edm.sample(
3368+
num_sample_steps = num_rollout_steps,
3369+
atom_feats = atom_feats,
3370+
atompair_feats = atompair_feats,
3371+
atom_mask = atom_mask,
3372+
mask = mask,
3373+
single_trunk_repr = single,
3374+
single_inputs_repr = single_inputs,
3375+
pairwise_trunk = pairwise,
3376+
pairwise_rel_pos_feats = relative_position_encoding,
3377+
residue_atom_lens = residue_atom_lens
3378+
)
33643379

33653380
if self.packed_atom_repr:
33663381
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
33673382
else:
33683383
pred_atom_pos = einx.get_at('b (n [w]) c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
33693384

33703385
logits = self.confidence_head(
3371-
single_repr = single,
3372-
single_inputs_repr = single_inputs,
3373-
pairwise_repr = pairwise,
3374-
pred_atom_pos = pred_atom_pos,
3386+
single_repr = single.detach(),
3387+
single_inputs_repr = single_inputs.detach(),
3388+
pairwise_repr = pairwise.detach(),
3389+
pred_atom_pos = pred_atom_pos.detach(),
33753390
mask = mask,
33763391
return_pae_logits = return_pae_logits
33773392
)

0 commit comments

Comments
 (0)