Skip to content

Commit 872d736

Browse files
authored
attempt to identify and fix flaky test (#261)
1 parent 618c307 commit 872d736

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ env:
55
TYPECHECK: True
66
DEBUG: True
77
DEEPSPEED_CHECKPOINTING: False
8+
IS_GITHUB_CI: True
89

910
jobs:
1011
build:

alphafold3_pytorch/tensor_typing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def package_available(package_name: str) -> bool:
8787
else:
8888
checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant = False)
8989

90+
# check is github ci
91+
92+
IS_GITHUB_CI = env.bool('IS_GITHUB_CI', False)
93+
9094
# use env variable TYPECHECK to control whether to use beartype + jaxtyping
9195

9296
should_typecheck = env.bool('TYPECHECK', False)
@@ -115,5 +119,6 @@ def package_available(package_name: str) -> bool:
115119
should_typecheck,
116120
beartype_isinstance,
117121
checkpoint,
118-
IS_DEBUGGING
122+
IS_DEBUGGING,
123+
IS_GITHUB_CI
119124
]

tests/test_input.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
file_to_atom_input
1818
)
1919

20+
from alphafold3_pytorch.tensor_typing import IS_GITHUB_CI
21+
2022
from alphafold3_pytorch.data import mmcif_writing
2123

2224
from alphafold3_pytorch.life import (
@@ -230,6 +232,7 @@ def test_atompos_input():
230232
assert sampled_atom_pos.shape == (1, (5 + 4 + 21 + 3), 3)
231233

232234
def test_pdbinput_input():
235+
233236
"""Test the PDBInput class, particularly its input transformations for mmCIF files."""
234237
filepath = os.path.join("data", "test", "mmcifs", DATA_TEST_PDB_ID[1:3], f"{DATA_TEST_PDB_ID}-assembly1.cif")
235238
file_id = os.path.splitext(os.path.basename(filepath))[0]
@@ -242,14 +245,14 @@ def test_pdbinput_input():
242245
"contiguous_weight": 0.2,
243246
"spatial_weight": 0.4,
244247
"spatial_interface_weight": 0.4,
245-
"n_res": 64,
248+
"n_res": 4,
246249
},
247250
training=True,
248251
)
249252

250253
eval_pdb_input = PDBInput(filepath)
251254

252-
batched_atom_input = pdb_inputs_to_batched_atom_input(train_pdb_input, atoms_per_window=27)
255+
batched_atom_input = pdb_inputs_to_batched_atom_input(train_pdb_input, atoms_per_window=4)
253256

254257
# training
255258

@@ -262,12 +265,12 @@ def test_pdbinput_input():
262265
dim_token=2,
263266
dim_atom_inputs=3,
264267
dim_atompair_inputs=5,
265-
atoms_per_window=27,
268+
atoms_per_window=4,
266269
dim_template_feats=108,
267270
num_molecule_mods=4,
268271
num_dist_bins=64,
269-
num_rollout_steps=2,
270-
diffusion_num_augmentations=2,
272+
num_rollout_steps=1,
273+
diffusion_num_augmentations=1,
271274
confidence_head_kwargs=dict(pairformer_depth=1),
272275
template_embedder_kwargs=dict(pairformer_stack_depth=1),
273276
msa_module_kwargs=dict(depth=1, dim_msa=2),
@@ -288,9 +291,14 @@ def test_pdbinput_input():
288291
loss = alphafold3(**batched_atom_input.model_forward_dict())
289292
loss.backward()
290293

294+
# sampling is too much for github ci for now
295+
296+
if IS_GITHUB_CI:
297+
return
298+
291299
# sampling
292300

293-
batched_eval_atom_input = pdb_inputs_to_batched_atom_input(eval_pdb_input, atoms_per_window=27)
301+
batched_eval_atom_input = pdb_inputs_to_batched_atom_input(eval_pdb_input, atoms_per_window=4)
294302

295303
alphafold3.eval()
296304

0 commit comments

Comments
 (0)