|
12 | 12 |
|
13 | 13 | from alphafold3_pytorch import ( |
14 | 14 | Alphafold3, |
| 15 | + PDBDataset, |
15 | 16 | AtomInput, |
16 | 17 | DataLoader, |
17 | 18 | Trainer, |
@@ -106,7 +107,7 @@ def remove_test_folders(): |
106 | 107 | yield |
107 | 108 | shutil.rmtree('./test-folder') |
108 | 109 |
|
109 | | -def test_trainer(remove_test_folders): |
| 110 | +def test_trainer_with_mock_atom_input(remove_test_folders): |
110 | 111 |
|
111 | 112 | alphafold3 = Alphafold3( |
112 | 113 | dim_atom_inputs = 77, |
@@ -204,6 +205,135 @@ def test_trainer(remove_test_folders): |
204 | 205 |
|
205 | 206 | alphafold3 = Alphafold3.init_and_load('./test-folder/nested/folder2/training.pt') |
206 | 207 |
|
| 208 | +# testing trainer with pdb inputs |
| 209 | + |
@pytest.fixture()
def populate_mock_pdb_and_remove_test_folders():
    """Populate train/valid/test folders with mock PDB data, then clean up.

    Copies a known-good mmCIF file (7a4d, assembly 1) repeatedly into
    ./test-folder/data/{train,valid,test} so that ``PDBDataset`` has files
    to read, yields control to the test, and removes ./test-folder on
    teardown.
    """
    proj_root = Path('.')
    working_cif_file = proj_root / 'data' / 'test' / '7a4d-assembly1.cif'

    pytest_root_folder = Path('./test-folder')
    data_folder = pytest_root_folder / 'data'

    # (split folder, number of mock structures) for each dataset split
    splits = (
        (data_folder / 'train', 100),
        (data_folder / 'valid', 4),
        (data_folder / 'test', 2),
    )

    for folder, num_copies in splits:
        folder.mkdir(exist_ok = True, parents = True)

        for i in range(num_copies):
            # shutil.copy2 accepts path-like objects directly — no str() needed
            shutil.copy2(working_cif_file, folder / f'{i}.cif')

    yield

    # teardown: remove everything the fixture (and the test) wrote
    shutil.rmtree('./test-folder')
| 238 | + |
def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
    """End-to-end Trainer smoke test driven by mmCIF files on disk.

    Mirrors the mock-atom-input trainer test, but the datasets here are
    ``PDBDataset`` instances reading the fixture-populated folders.
    """

    # paths reused throughout the test
    af3_ckpt_path = './test-folder/nested/folder/af3'
    trainer_ckpt_path = './test-folder/nested/folder2/training.pt'
    checkpoint_folder = './test-folder/checkpoints'

    # tiny model so the test stays fast on cpu
    alphafold3 = Alphafold3(
        dim_atom = 8,
        dim_atompair = 8,
        dim_input_embedder_token = 8,
        dim_single = 8,
        dim_pairwise = 8,
        dim_token = 8,
        dim_atom_inputs = 3,
        dim_atompair_inputs = 1,
        atoms_per_window = 27,
        dim_template_feats = 44,
        num_dist_bins = 38,
        confidence_head_kwargs = dict(pairformer_depth = 1),
        template_embedder_kwargs = dict(pairformer_stack_depth = 1),
        msa_module_kwargs = dict(depth = 1),
        pairformer_stack = dict(depth = 1),
        diffusion_module_kwargs = dict(
            atom_encoder_depth = 1,
            token_transformer_depth = 1,
            atom_decoder_depth = 1,
        ),
    )

    train_dataset = PDBDataset('./test-folder/data/train')
    valid_dataset = PDBDataset('./test-folder/data/valid')
    test_dataset = PDBDataset('./test-folder/data/test')

    # saving and loading of Alphafold3 weights, independent of lightning

    batch = next(iter(DataLoader(train_dataset, batch_size = 2)))

    alphafold3.eval()
    _, breakdown = alphafold3(**asdict(batch), return_loss_breakdown = True)
    distogram_before = breakdown.distogram

    alphafold3.save(af3_ckpt_path, overwrite = True)

    # reload from scratch, along with the saved hyperparameters

    alphafold3 = Alphafold3.init_and_load(af3_ckpt_path)

    alphafold3.eval()
    _, breakdown = alphafold3(**asdict(batch), return_loss_breakdown = True)

    assert torch.allclose(distogram_before, breakdown.distogram)

    # training + validation

    trainer = Trainer(
        alphafold3,
        dataset = train_dataset,
        valid_dataset = valid_dataset,
        test_dataset = test_dataset,
        accelerator = 'cpu',
        num_train_steps = 2,
        batch_size = 1,
        valid_every = 1,
        grad_accum_every = 2,
        checkpoint_every = 1,
        checkpoint_folder = checkpoint_folder,
        overwrite_checkpoints = True,
        ema_kwargs = dict(
            use_foreach = True,
            update_after_step = 0,
            update_every = 1
        )
    )

    trainer()

    # a checkpoint should have been written during training

    assert Path(f'{checkpoint_folder}/({trainer.train_id})_af3.ckpt.1.pt').exists()

    # the latest checkpoint is loadable by pointing at the directory

    trainer.load(checkpoint_folder, strict = False)
    assert exists(trainer.model_loaded_from_path)

    # trainer-level save / load roundtrip

    trainer.save(trainer_ckpt_path, overwrite = True)
    trainer.load(trainer_ckpt_path, strict = False)

    # only the model weights, as needed for fine-tuning logic

    trainer.load(trainer_ckpt_path, only_model = True, strict = False)

    # Alphafold3 can also be initialized straight from a trainer checkpoint

    alphafold3 = Alphafold3.init_and_load(trainer_ckpt_path)
| 336 | + |
207 | 337 | # test use of collation fn outside of trainer |
208 | 338 |
|
209 | 339 | def test_collate_fn(): |
|
0 commit comments