Skip to content

Commit be911d0

Browse files
committed
bring in joblib
1 parent 1cde46f commit be911d0

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from torch.nn.utils.rnn import pad_sequence
2424

2525
from loguru import logger
26+
from joblib import Parallel, delayed
27+
2628
from pdbeccdutils.core import ccd_reader
2729

2830
from rdkit import Chem
@@ -221,7 +223,8 @@ def pdb_dataset_to_atom_inputs(
221223
output_atom_folder: str | Path | None = None,
222224
indices: Iterable | None = None,
223225
return_atom_dataset = False,
224-
verbose = True
226+
n_jobs: int = 8,
227+
parallel_kwargs: dict = dict()
225228
) -> Path | AtomDataset:
226229

227230
if not exists(output_atom_folder):
@@ -235,26 +238,21 @@ def pdb_dataset_to_atom_inputs(
235238
if not exists(indices):
236239
indices = torch.randperm(len(pdb_dataset)).tolist()
237240

238-
indices = iter(indices)
239-
240241
to_atom_input_fn = compose(
241242
pdb_input_to_molecule_input,
242243
molecule_to_atom_input
243244
)
244245

245-
while index := next(indices, None):
246-
if not exists(index):
247-
break
248-
246+
@delayed
247+
def pdb_input_to_atom_file(index, path):
249248
pdb_input = pdb_dataset[index]
250249

251250
atom_input = to_atom_input_fn(pdb_input)
252-
atom_input_path = output_atom_folder / f'{index}.pt'
251+
atom_input_path = path / f'{index}.pt'
253252

254253
atom_input_to_file(atom_input, atom_input_path)
255254

256-
if verbose:
257-
logger.info(f'converted pdb input with index {index} to {str(atom_input_path)}')
255+
Parallel(n_jobs = n_jobs, **parallel_kwargs)(pdb_input_to_atom_file(index, output_atom_folder) for index in indices)
258256

259257
if not return_atom_dataset:
260258
return output_atom_folder

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.48"
3+
version = "0.2.49"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }
@@ -30,6 +30,7 @@ dependencies = [
3030
"einx>=0.2.2",
3131
"ema-pytorch>=0.5.0",
3232
"environs",
33+
"joblib",
3334
"gemmi>=0.6.6",
3435
"frame-averaging-pytorch>=0.0.18",
3536
"huggingface_hub>=0.21.4",

0 commit comments

Comments
 (0)