Skip to content

Commit 6811dfd

Browse files
committed
fix to tests
1 parent f0de8b2 commit 6811dfd

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3071,9 +3071,7 @@ def __init__(
30713071
# residue or nucleotide modifications
30723072

30733073
num_molecule_mods = default(num_molecule_mods, 0)
3074-
30753074
has_molecule_mod_embeds = num_molecule_mods > 0
3076-
self.num_molecule_mods = num_molecule_mods
30773075

30783076
if has_molecule_mod_embeds:
30793077
self.molecule_mod_embeds = nn.Embedding(num_molecule_mods, dim_single)
@@ -3233,6 +3231,11 @@ def __init__(
32333231

32343232
self.register_buffer('zero', torch.tensor(0.), persistent = False)
32353233

3234+
# some shorthand for jaxtyping
3235+
3236+
self.dapi = self.dim_atompair_inputs
3237+
self.dai = self.dim_atom_inputs
3238+
32363239
@property
32373240
def device(self):
32383241
return self.zero.device
@@ -3309,16 +3312,16 @@ def init_and_load(
33093312
def forward(
33103313
self,
33113314
*,
3312-
atom_inputs: Float['b m dai'],
3313-
atompair_inputs: Float['b m m dapi'] | Float['b nw w1 w2 dapi'],
3315+
atom_inputs: Float['b m {self.dai}'],
3316+
atompair_inputs: Float['b m m {self.dapi}'] | Float['b nw w1 w2 {self.dapi}'],
33143317
additional_molecule_feats: Int[f'b n {ADDITIONAL_MOLECULE_FEATS}'],
33153318
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}'],
33163319
molecule_atom_lens: Int['b n'],
33173320
molecule_ids: Int['b n'],
33183321
additional_token_feats: Float['b n {self.dim_additional_token_feats}'] | None = None,
33193322
atom_ids: Int['b m'] | None = None,
33203323
atompair_ids: Int['b m m'] | Int['b nw w1 w2'] | None = None,
3321-
is_molecule_mod: Bool['b n {self.num_molecule_mods}'] | None = None,
3324+
is_molecule_mod: Bool['b n num_mods'] | None = None,
33223325
atom_mask: Bool['b m'] | None = None,
33233326
atom_parent_ids: Int['b m'] | None = None,
33243327
token_bonds: Bool['b n n'] | None = None,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.110"
3+
version = "0.1.111"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ def test_alphafold3(
473473
atoms_per_window = atoms_per_window,
474474
dim_template_feats = 44,
475475
num_dist_bins = 38,
476+
num_molecule_mods = num_molecule_mods,
476477
confidence_head_kwargs = dict(
477478
pairformer_depth = 1
478479
),

0 commit comments

Comments
 (0)