Skip to content

Commit cdbc900

Browse files
authored
Fix batching of missing_atom_masks, and clean up code (#160)
* Update filter_pdb_test_mmcifs.py
* Update filter_pdb_train_mmcifs.py
* Update inputs.py
* Update weighted_pdb_sampler.py
* Update inputs.py
1 parent: addf096 · commit: cdbc900

File tree

4 files changed: 29 additions (+29) and 33 deletions (-33).

alphafold3_pytorch/data/weighted_pdb_sampler.py

Lines changed: 18 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,7 @@ def __init__(
193193
alpha_nuc: float = 3.0,
194194
alpha_ligand: float = 1.0,
195195
pdb_ids_to_skip: List[str] = [],
196-
subset_to_ids: list[int] | None = None,
196+
pdb_ids_to_keep: list[str] | None = None,
197197
):
198198
# Load chain and interface mappings
199199
if not isinstance(chain_mapping_paths, list):
@@ -226,28 +226,24 @@ def __init__(
226226
"Precomputing chain and interface weights. This may take several minutes to complete."
227227
)
228228

229-
# Subset to specific indices if provided
230-
if exists(subset_to_ids):
231-
chain_mapping = (
232-
chain_mapping.with_row_index()
233-
.filter(pl.col("index").is_in(subset_to_ids))
234-
.select(["pdb_id", "chain_id", "molecule_id", "cluster_id"])
229+
# Subset to specific PDB IDs if provided
230+
if exists(pdb_ids_to_keep):
231+
chain_mapping = chain_mapping.filter(pl.col("pdb_id").is_in(pdb_ids_to_keep)).select(
232+
["pdb_id", "chain_id", "molecule_id", "cluster_id"]
235233
)
236-
interface_mapping = (
237-
interface_mapping.with_row_index()
238-
.filter(pl.col("index").is_in(subset_to_ids))
239-
.select(
240-
[
241-
"pdb_id",
242-
"interface_chain_id_1",
243-
"interface_chain_id_2",
244-
"interface_molecule_id_1",
245-
"interface_molecule_id_2",
246-
"interface_chain_cluster_id_1",
247-
"interface_chain_cluster_id_2",
248-
"interface_cluster_id",
249-
]
250-
)
234+
interface_mapping = interface_mapping.filter(
235+
pl.col("pdb_id").is_in(pdb_ids_to_keep)
236+
).select(
237+
[
238+
"pdb_id",
239+
"interface_chain_id_1",
240+
"interface_chain_id_2",
241+
"interface_molecule_id_1",
242+
"interface_molecule_id_2",
243+
"interface_chain_cluster_id_1",
244+
"interface_chain_cluster_id_2",
245+
"interface_cluster_id",
246+
]
251247
)
252248

253249
chain_mapping.insert_column(

alphafold3_pytorch/inputs.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@
6363
is_polymer,
6464
)
6565
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
66-
from alphafold3_pytorch.utils.utils import default, exists, first, identity
66+
from alphafold3_pytorch.utils.utils import default, exists, first
6767

6868
# silence RDKit's warnings
6969

@@ -166,7 +166,8 @@ def inner(x, *args, **kwargs):
166166
}
167167

168168
ATOM_DEFAULT_PAD_VALUES = dict(
169-
molecule_atom_lens = 0
169+
molecule_atom_lens = 0,
170+
missing_atom_mask = True
170171
)
171172

172173
@typecheck
@@ -2678,10 +2679,6 @@ def __init__(
26782679
assert folder.exists() and folder.is_dir(), f"{str(folder)} does not exist for PDBDataset"
26792680
self.folder = folder
26802681

2681-
self.files = {
2682-
os.path.splitext(os.path.basename(file.name))[0]: file
2683-
for file in folder.glob(os.path.join("**", "*.cif"))
2684-
}
26852682
self.sampler = sampler
26862683
self.sample_type = sample_type
26872684
self.training = training
@@ -2700,9 +2697,14 @@ def __init__(
27002697
if exists(self.sampler):
27012698
sampler_pdb_ids = set(self.sampler.mappings.get_column("pdb_id").to_list())
27022699
self.files = {
2703-
file: filepath
2704-
for (file, filepath) in self.files.items()
2705-
if file in sampler_pdb_ids
2700+
os.path.splitext(os.path.basename(filepath.name))[0]: filepath
2701+
for filepath in folder.glob(os.path.join("**", "*.cif"))
2702+
if os.path.splitext(os.path.basename(filepath.name))[0] in sampler_pdb_ids
2703+
}
2704+
else:
2705+
self.files = {
2706+
os.path.splitext(os.path.basename(file.name))[0]: file
2707+
for file in folder.glob(os.path.join("**", "*.cif"))
27062708
}
27072709

27082710
if exists(sample_only_pdb_ids):

scripts/filter_pdb_test_mmcifs.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
import glob
2525
import os
2626
from datetime import datetime
27-
from typing import List, Tuple
2827

2928
import timeout_decorator
3029
from tqdm.contrib.concurrent import process_map

scripts/filter_pdb_train_mmcifs.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@
3737
import random
3838
from datetime import datetime
3939
from operator import itemgetter
40-
from typing import Dict, List, Set, Tuple
4140

4241
import numpy as np
4342
import timeout_decorator

0 commit comments

Comments
 (0)