Skip to content

Commit a70c6d3

Browse files
authored
switch to pytest-shard (#137)
switch to pytest-shard
1 parent 099279a commit a70c6d3

File tree

12 files changed

+18
-12
lines changed

12 files changed

+18
-12
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
strategy:
1313
fail-fast: false
1414
matrix:
15-
group: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
15+
group: [0, 1, 2, 3, 4]
1616

1717
steps:
1818
- uses: actions/checkout@v4
@@ -27,4 +27,4 @@ jobs:
2727
python -m pip install -e .[test]
2828
- name: Test with pytest
2929
run: |
30-
python -m pytest --test-group-count 10 --test-group-random-seed 42 --test-group ${{ matrix.group }} tests/
30+
python -m pytest --num-shards 5 --shard-id ${{ matrix.group }} tests/

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
7777
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
7878
additional_token_feats = torch.randn(2, seq_len, 2)
7979
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
80+
is_molecule_mod = torch.randint(0, 2, (2, seq_len, 4)).bool()
8081
molecule_ids = torch.randint(0, 32, (2, seq_len))
8182

8283
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
@@ -107,6 +108,7 @@ loss = alphafold3(
107108
additional_molecule_feats = additional_molecule_feats,
108109
additional_token_feats = additional_token_feats,
109110
is_molecule_types = is_molecule_types,
111+
is_molecule_mod = is_molecule_mod,
110112
msa = msa,
111113
msa_mask = msa_mask,
112114
templates = template_feats,
@@ -134,6 +136,7 @@ sampled_atom_pos = alphafold3(
134136
additional_molecule_feats = additional_molecule_feats,
135137
additional_token_feats = additional_token_feats,
136138
is_molecule_types = is_molecule_types,
139+
is_molecule_mod = is_molecule_mod,
137140
msa = msa,
138141
msa_mask = msa_mask,
139142
templates = template_feats,
@@ -180,6 +183,7 @@ alphafold3 = Alphafold3(
180183
atoms_per_window = 27,
181184
dim_template_feats = 44,
182185
num_dist_bins = 38,
186+
num_molecule_mods = 0,
183187
confidence_head_kwargs = dict(
184188
pairformer_depth = 1
185189
),

alphafold3_pytorch/alphafold3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4431,7 +4431,7 @@ def __init__(
44314431
num_molecule_types: int = NUM_MOLECULE_IDS, # restype in additional residue information, apparently 32. will do 33 to account for metal ions
44324432
num_atom_embeds: int | None = None,
44334433
num_atompair_embeds: int | None = None,
4434-
num_molecule_mods: int | None = None,
4434+
num_molecule_mods: int | None = DEFAULT_NUM_MOLECULE_MODS,
44354435
distance_bins: List[float] = torch.linspace(3, 20, 38).float().tolist(),
44364436
ignore_index = -1,
44374437
num_dist_bins: int | None = None,

alphafold3_pytorch/inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import partial, wraps
88
from itertools import groupby
99
from collections import defaultdict
10-
from collections.abc import Iterableassignment
10+
from collections.abc import Iterable
1111
from dataclasses import asdict, dataclass, field
1212
from typing import Any, Callable, Dict, List, Literal, Set, Tuple, Type
1313

alphafold3_pytorch/mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(
1818
data_length,
1919
max_seq_len = 16,
2020
atoms_per_window = 4,
21-
has_molecule_mods = False
21+
has_molecule_mods = True
2222
):
2323
self.data_length = data_length
2424
self.max_seq_len = max_seq_len

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.96"
3+
version = "0.2.97"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }
@@ -62,7 +62,7 @@ Repository = "https://github.com/lucidrains/alphafold3-pytorch"
6262
examples = []
6363
test = [
6464
"pytest",
65-
"pytest-split-tests",
65+
"pytest-shard",
6666
]
6767

6868
[build-system]

tests/configs/trainer_with_atom_dataset_created_from_pdb.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ model:
1717
num_plddt_bins: 50
1818
num_pde_bins: 64
1919
num_pae_bins: 64
20-
num_molecule_mods: 1
2120
sigma_data: 16
2221
diffusion_num_augmentations: 4
2322
loss_confidence_weight: 0.0001

tests/configs/trainer_with_pdb_dataset.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ model:
1717
num_plddt_bins: 50
1818
num_pde_bins: 64
1919
num_pae_bins: 64
20-
num_molecule_mods: 1
2120
sigma_data: 16
2221
diffusion_num_augmentations: 4
2322
loss_confidence_weight: 0.0001

tests/configs/trainer_with_pdb_dataset_and_weighted_sampling.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ model:
1717
num_plddt_bins: 50
1818
num_pde_bins: 64
1919
num_pae_bins: 64
20-
num_molecule_mods: 1
2120
sigma_data: 16
2221
diffusion_num_augmentations: 4
2322
loss_confidence_weight: 0.0001

tests/test_af3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,7 @@ def test_alphafold3_without_msa_and_templates():
689689
dim_atom_inputs = 77,
690690
dim_template_feats = 44,
691691
num_dist_bins = 38,
692+
num_molecule_mods = 0,
692693
checkpoint_trunk_pairformer = True,
693694
checkpoint_diffusion_token_transformer = True,
694695
confidence_head_kwargs = dict(
@@ -767,6 +768,7 @@ def test_alphafold3_force_return_loss():
767768
dim_atom_inputs = 77,
768769
dim_template_feats = 44,
769770
num_dist_bins = 38,
771+
num_molecule_mods = 0,
770772
confidence_head_kwargs = dict(
771773
pairformer_depth = 1
772774
),
@@ -851,6 +853,7 @@ def test_alphafold3_force_return_loss_with_confidence_logits():
851853
dim_atom_inputs = 77,
852854
dim_template_feats = 44,
853855
num_dist_bins = 38,
856+
num_molecule_mods = 0,
854857
confidence_head_kwargs = dict(
855858
pairformer_depth = 1
856859
),
@@ -913,6 +916,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
913916
alphafold3 = Alphafold3(
914917
num_atom_embeds = 7,
915918
num_atompair_embeds = 3,
919+
num_molecule_mods = 0,
916920
dim_atom_inputs = 77,
917921
dim_template_feats = 44
918922
)

0 commit comments

Comments
 (0)