Skip to content

Commit 6b643db

Browse files
committed
able to pass in initial valid_atom_indices_for_frame
1 parent fb0ca98 commit 6b643db

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5277,6 +5277,7 @@ def forward(
52775277
atom_mask: Bool['b m'] | None = None,
52785278
missing_atom_mask: Bool['b m'] | None = None,
52795279
atom_indices_for_frame: Int['b n 3'] | None = None,
5280+
valid_atom_indices_for_frame: Bool['b n'] | None = None,
52805281
atom_parent_ids: Int['b m'] | None = None,
52815282
token_bonds: Bool['b n n'] | None = None,
52825283
msa: Float['b s n d'] | None = None,
@@ -5330,9 +5331,10 @@ def forward(
53305331
valid_distogram_mask = distogram_atom_indices >= 0 & valid_atom_len_mask
53315332
distogram_atom_indices = distogram_atom_indices.masked_fill(~valid_distogram_mask, 0)
53325333

5333-
valid_atom_indices_for_frame = None
53345334
if exists(atom_indices_for_frame):
5335-
valid_atom_indices_for_frame = (atom_indices_for_frame >= 0).all(dim = -1) & valid_atom_len_mask
5335+
valid_atom_indices_for_frame = default(valid_atom_indices_for_frame, torch.ones_like(molecule_atom_lens).bool())
5336+
5337+
valid_atom_indices_for_frame = valid_atom_indices_for_frame & (atom_indices_for_frame >= 0).all(dim = -1) & valid_atom_len_mask
53365338
atom_indices_for_frame = einx.where('b n, b n three, -> b n three', valid_atom_indices_for_frame, atom_indices_for_frame, 0)
53375339

53385340
assert exists(molecule_atom_lens) or exists(atom_mask)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.1"
3+
version = "0.3.2"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)