Skip to content

Commit d416481

Browse files
committed
able to set n_jobs for joblib from trainer yaml
1 parent be911d0 commit d416481

File tree

3 files changed

+16
-24
lines changed

3 files changed

+16
-24
lines changed

alphafold3_pytorch/configs.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class DatasetConfig(BaseModelWithExtra):
147147
valid_folder: DirectoryPath | None = None
148148
test_folder: DirectoryPath | None = None
149149
convert_pdb_to_atom: bool = False
150+
pdb_to_atom_kwargs: dict = dict()
150151
train_weighted_sampler: WeightedPDBSamplerConfig | None = None
151152
kwargs: dict = dict()
152153

@@ -222,6 +223,7 @@ def create_instance(
222223
dataset_kwargs = dataset_config.kwargs
223224

224225
convert_pdb_to_atom = dataset_config.convert_pdb_to_atom
226+
pdb_to_atom_kwargs = dataset_config.pdb_to_atom_kwargs
225227

226228
if convert_pdb_to_atom:
227229
assert dataset_type == 'pdb', 'must be `pdb` dataset_type if `convert_pdb_to_atom` is set to True'
@@ -233,35 +235,22 @@ def create_instance(
233235
else:
234236
raise ValueError(f'unhandled dataset_type {dataset_type}')
235237

236-
train_folder, valid_folder, test_folder = tuple(getattr(dataset_config, key, None) for key in ('train_folder', 'valid_folder', 'test_folder'))
238+
# create dataset for train, valid, and test
237239

238-
if exists(train_folder):
239-
assert 'dataset' not in trainer_kwargs
240+
for trainer_kwarg_key, config_key in (('dataset', 'train_folder'), ('valid_dataset', 'valid_folder'), ('test_dataset', 'test_folder')):
241+
folder = getattr(dataset_config, config_key, None)
240242

241-
dataset = dataset_klass(train_folder, **dataset_kwargs)
243+
if not exists(folder):
244+
continue
242245

243-
if convert_pdb_to_atom:
244-
dataset = pdb_dataset_to_atom_inputs(dataset, return_atom_dataset = True)
245-
246-
trainer_kwargs.update(dataset = dataset)
247-
248-
if exists(valid_folder):
249-
assert 'valid_dataset' not in trainer_kwargs
250-
dataset = dataset_klass(valid_folder, **dataset_kwargs)
251-
252-
if convert_pdb_to_atom:
253-
dataset = pdb_dataset_to_atom_inputs(dataset, return_atom_dataset = True)
254-
255-
trainer_kwargs.update(valid_dataset = dataset)
246+
assert trainer_kwarg_key not in trainer_kwargs
256247

257-
if exists(test_folder):
258-
assert 'test_dataset' not in trainer_kwargs
259-
dataset = dataset_klass(test_folder, **dataset_kwargs)
248+
dataset = dataset_klass(folder, **dataset_kwargs)
260249

261250
if convert_pdb_to_atom:
262-
dataset = pdb_dataset_to_atom_inputs(dataset, return_atom_dataset = True)
251+
dataset = pdb_dataset_to_atom_inputs(dataset, return_atom_dataset = True, **pdb_to_atom_kwargs)
263252

264-
trainer_kwargs.update(test_dataset = dataset)
253+
trainer_kwargs.update(**{trainer_kwarg_key: dataset})
265254

266255
# handle weighted pdb sampling
267256

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.49"
3+
version = "0.2.50"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/configs/trainer_with_atom_dataset_created_from_pdb.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,8 @@ checkpoint_folder: ./checkpoints
5757
overwrite_checkpoints: false
5858
dataset_config:
5959
dataset_type: pdb
60-
convert_pdb_to_atom: true
6160
train_folder: ./test-folder/data/train
61+
convert_pdb_to_atom: true
62+
pdb_to_atom_kwargs:
63+
n_jobs: 16
64+

0 commit comments

Comments
 (0)