@@ -226,10 +226,10 @@ def populate_mock_pdb_and_remove_test_folders():
226226 for i in range (100 ):
227227 shutil .copy2 (str (working_cif_file ), str (train_folder / f'{ i } .cif' ))
228228
229- for i in range (4 ):
229+ for i in range (1 ):
230230 shutil .copy2 (str (working_cif_file ), str (valid_folder / f'{ i } .cif' ))
231231
232- for i in range (2 ):
232+ for i in range (1 ):
233233 shutil .copy2 (str (working_cif_file ), str (test_folder / f'{ i } .cif' ))
234234
235235 yield
@@ -239,25 +239,41 @@ def populate_mock_pdb_and_remove_test_folders():
239239def test_trainer_with_pdb_input (populate_mock_pdb_and_remove_test_folders ):
240240
241241 alphafold3 = Alphafold3 (
242- dim_atom = 8 ,
243- dim_atompair = 8 ,
244- dim_input_embedder_token = 8 ,
245- dim_single = 8 ,
246- dim_pairwise = 8 ,
247- dim_token = 8 ,
242+ dim_atom = 4 ,
243+ dim_atompair = 4 ,
244+ dim_input_embedder_token = 4 ,
245+ dim_single = 4 ,
246+ dim_pairwise = 4 ,
247+ dim_token = 4 ,
248248 dim_atom_inputs = 3 ,
249249 dim_atompair_inputs = 1 ,
250250 atoms_per_window = 27 ,
251251 dim_template_feats = 44 ,
252252 num_dist_bins = 38 ,
253- confidence_head_kwargs = dict (pairformer_depth = 1 ),
253+ confidence_head_kwargs = dict (
254+ pairformer_depth = 1 ,
255+ ),
254256 template_embedder_kwargs = dict (pairformer_stack_depth = 1 ),
255257 msa_module_kwargs = dict (depth = 1 ),
256- pairformer_stack = dict (depth = 1 ),
258+ pairformer_stack = dict (
259+ depth = 1 ,
260+ pair_bias_attn_dim_head = 4 ,
261+ pair_bias_attn_heads = 2 ,
262+ ),
257263 diffusion_module_kwargs = dict (
258264 atom_encoder_depth = 1 ,
259265 token_transformer_depth = 1 ,
260266 atom_decoder_depth = 1 ,
267+ atom_decoder_kwargs = dict (
268+ attn_pair_bias_kwargs = dict (
269+ dim_head = 4
270+ )
271+ ),
272+ atom_encoder_kwargs = dict (
273+ attn_pair_bias_kwargs = dict (
274+ dim_head = 4
275+ )
276+ )
261277 ),
262278 )
263279
@@ -267,7 +283,7 @@ def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
267283
268284 # test saving and loading from Alphafold3, independent of lightning
269285
270- dataloader = DataLoader (dataset , batch_size = 2 )
286+ dataloader = DataLoader (dataset , batch_size = 1 )
271287 inputs = next (iter (dataloader ))
272288
273289 alphafold3 .eval ()
@@ -298,7 +314,7 @@ def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
298314 num_train_steps = 2 ,
299315 batch_size = 1 ,
300316 valid_every = 1 ,
301- grad_accum_every = 2 ,
317+ grad_accum_every = 1 ,
302318 checkpoint_every = 1 ,
303319 checkpoint_folder = './test-folder/checkpoints' ,
304320 overwrite_checkpoints = True ,
0 commit comments