Skip to content

Commit 375fd5e

Browse files
committed
makes more sense the other way around, and also make sure valid and test pdbs can be converted to atom inputs on the fly
1 parent 47dfeb9 commit 375fd5e

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

alphafold3_pytorch/configs.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def create_instance(
224224
convert_pdb_to_atom = dataset_config.convert_pdb_to_atom
225225

226226
if convert_pdb_to_atom:
227-
assert dataset_type == 'atom', 'must be `atom` dataset_type if `convert_pdb_to_atom` is set to True'
227+
assert dataset_type == 'pdb', 'must be `pdb` dataset_type if `convert_pdb_to_atom` is set to True'
228228

229229
if dataset_type == 'pdb':
230230
dataset_klass = PDBDataset
@@ -238,21 +238,29 @@ def create_instance(
238238
if exists(train_folder):
239239
assert 'dataset' not in trainer_kwargs
240240

241+
dataset = dataset_klass(train_folder, **dataset_kwargs)
242+
241243
if convert_pdb_to_atom:
242-
pdb_dataset = PDBDataset(train_folder, **dataset_kwargs)
243-
train_folder = pdb_dataset_to_atom_inputs(pdb_dataset)
244+
dataset = pdb_dataset_to_atom_inputs(dataset, return_atom_dataset = True)
244245

245-
dataset = dataset_klass(train_folder, **dataset_kwargs)
246246
trainer_kwargs.update(dataset = dataset)
247247

248248
if exists(valid_folder):
249249
assert 'valid_dataset' not in trainer_kwargs
250250
dataset = dataset_klass(valid_folder, **dataset_kwargs)
251+
252+
if convert_pdb_to_atom:
253+
dataset = pdb_dataset_to_atom_inputs(dataset, return_atom_dataset = True)
254+
251255
trainer_kwargs.update(valid_dataset = dataset)
252256

253257
if exists(test_folder):
254258
assert 'test_dataset' not in trainer_kwargs
255259
dataset = dataset_klass(test_folder, **dataset_kwargs)
260+
261+
if convert_pdb_to_atom:
262+
dataset = pdb_dataset_to_atom_inputs(dataset, return_atom_dataset = True)
263+
256264
trainer_kwargs.update(test_dataset = dataset)
257265

258266
# handle weighted pdb sampling

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.44"
3+
version = "0.2.45"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/configs/trainer_with_atom_dataset_created_from_pdb.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ checkpoint_every: 1
5656
checkpoint_folder: ./checkpoints
5757
overwrite_checkpoints: false
5858
dataset_config:
59-
dataset_type: atom
59+
dataset_type: pdb
6060
convert_pdb_to_atom: true
6161
train_folder: ./test-folder/data/train

0 commit comments

Comments
 (0)