1717 file_to_atom_input
1818)
1919
20+ from alphafold3_pytorch .tensor_typing import IS_GITHUB_CI
21+
2022from alphafold3_pytorch .data import mmcif_writing
2123
2224from alphafold3_pytorch .life import (
@@ -230,6 +232,7 @@ def test_atompos_input():
230232 assert sampled_atom_pos .shape == (1 , (5 + 4 + 21 + 3 ), 3 )
231233
232234def test_pdbinput_input ():
235+
233236 """Test the PDBInput class, particularly its input transformations for mmCIF files."""
234237 filepath = os .path .join ("data" , "test" , "mmcifs" , DATA_TEST_PDB_ID [1 :3 ], f"{ DATA_TEST_PDB_ID } -assembly1.cif" )
235238 file_id = os .path .splitext (os .path .basename (filepath ))[0 ]
@@ -242,14 +245,14 @@ def test_pdbinput_input():
242245 "contiguous_weight" : 0.2 ,
243246 "spatial_weight" : 0.4 ,
244247 "spatial_interface_weight" : 0.4 ,
245- "n_res" : 64 ,
248+ "n_res" : 4 ,
246249 },
247250 training = True ,
248251 )
249252
250253 eval_pdb_input = PDBInput (filepath )
251254
252- batched_atom_input = pdb_inputs_to_batched_atom_input (train_pdb_input , atoms_per_window = 27 )
255+ batched_atom_input = pdb_inputs_to_batched_atom_input (train_pdb_input , atoms_per_window = 4 )
253256
254257 # training
255258
@@ -262,12 +265,12 @@ def test_pdbinput_input():
262265 dim_token = 2 ,
263266 dim_atom_inputs = 3 ,
264267 dim_atompair_inputs = 5 ,
265- atoms_per_window = 27 ,
268+ atoms_per_window = 4 ,
266269 dim_template_feats = 108 ,
267270 num_molecule_mods = 4 ,
268271 num_dist_bins = 64 ,
269- num_rollout_steps = 2 ,
270- diffusion_num_augmentations = 2 ,
272+ num_rollout_steps = 1 ,
273+ diffusion_num_augmentations = 1 ,
271274 confidence_head_kwargs = dict (pairformer_depth = 1 ),
272275 template_embedder_kwargs = dict (pairformer_stack_depth = 1 ),
273276 msa_module_kwargs = dict (depth = 1 , dim_msa = 2 ),
@@ -288,9 +291,14 @@ def test_pdbinput_input():
288291 loss = alphafold3 (** batched_atom_input .model_forward_dict ())
289292 loss .backward ()
290293
294+ # sampling is too much for github ci for now
295+
296+ if IS_GITHUB_CI :
297+ return
298+
291299 # sampling
292300
293- batched_eval_atom_input = pdb_inputs_to_batched_atom_input (eval_pdb_input , atoms_per_window = 27 )
301+ batched_eval_atom_input = pdb_inputs_to_batched_atom_input (eval_pdb_input , atoms_per_window = 4 )
294302
295303 alphafold3 .eval ()
296304
0 commit comments