Skip to content

Commit 5169c53

Browse files
committed
whether using packed or unpacked atom repr must be explicitly given to diffusion module
1 parent b4b1ed5 commit 5169c53

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,6 +1566,7 @@ def __init__(
15661566
atom_encoder_kwargs: dict = dict(),
15671567
atom_decoder_kwargs: dict = dict(),
15681568
token_transformer_kwargs: dict = dict(),
1569+
packed_atom_repr = True,
15691570
use_linear_attn = False,
15701571
linear_attn_kwargs: dict = dict(
15711572
heads = 8,
@@ -1574,6 +1575,7 @@ def __init__(
15741575
):
15751576
super().__init__()
15761577

1578+
self.packed_atom_repr = packed_atom_repr
15771579
self.atoms_per_window = atoms_per_window
15781580

15791581
# conditioning
@@ -1696,9 +1698,9 @@ def forward(
16961698
residue_atom_lens: Int['b n'] | None = None
16971699
):
16981700
w = self.atoms_per_window
1699-
is_unpacked_repr = exists(w)
1701+
is_unpacked_repr = not self.packed_atom_repr
17001702

1701-
if not is_unpacked_repr:
1703+
if self.packed_atom_repr:
17021704
assert exists(residue_atom_lens)
17031705

17041706
# in the paper, it seems they pack the atom feats
@@ -2067,7 +2069,7 @@ def forward(
20672069

20682070
if exists(additional_residue_feats):
20692071
w = self.net.atoms_per_window
2070-
is_unpacked_repr = exists(w)
2072+
is_unpacked_repr = not self.net.packed_atom_repr
20712073

20722074
is_nucleotide_or_ligand_fields = (additional_residue_feats[..., 7:] != 0.).unbind(dim = -1)
20732075

@@ -2807,9 +2809,6 @@ def __init__(
28072809

28082810
# atoms per window if using unpacked representation
28092811

2810-
if packed_atom_repr:
2811-
atoms_per_window = None
2812-
28132812
self.atoms_per_window = atoms_per_window
28142813

28152814
# augmentation
@@ -2892,6 +2891,7 @@ def __init__(
28922891
self.diffusion_module = DiffusionModule(
28932892
dim_pairwise_trunk = dim_pairwise,
28942893
dim_pairwise_rel_pos_feats = dim_pairwise,
2894+
packed_atom_repr = packed_atom_repr,
28952895
atoms_per_window = atoms_per_window,
28962896
dim_pairwise = dim_pairwise,
28972897
sigma_data = sigma_data,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.63"
3+
version = "0.0.64"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def test_diffusion_module():
241241

242242
diffusion_module = DiffusionModule(
243243
atoms_per_window = 27,
244+
packed_atom_repr = False,
244245
dim_pairwise_trunk = 128,
245246
dim_pairwise_rel_pos_feats = 12,
246247
atom_encoder_depth = 1,

0 commit comments

Comments
 (0)