Skip to content

Commit c59cb14

Browse files
authored
Fix cluster ID arrangement for chains (#97)
Address issue related to cluster IDs being bundled together incorrectly. Sped up initialization of WeightedPDBSampler by removing some redundant typechecks.
1 parent 5901252 commit c59cb14

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

alphafold3_pytorch/data/weighted_pdb_sampler.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from scripts.cluster_pdb_mmcifs import CLUSTERING_MOLECULE_TYPE
1010

1111

12-
@typecheck
1312
def get_chain_count(molecule_type: CLUSTERING_MOLECULE_TYPE) -> Tuple[int, int, int]:
1413
"""
1514
Returns the number of protein, nucleic acid, and ligand chains in a
@@ -31,7 +30,6 @@ def get_chain_count(molecule_type: CLUSTERING_MOLECULE_TYPE) -> Tuple[int, int,
3130
raise ValueError(f"Unknown molecule type: {molecule_type}")
3231

3332

34-
@typecheck
3533
def calculate_weight(
3634
alphas: Dict[str, float],
3735
beta: float,
@@ -84,7 +82,7 @@ def get_interface_weight(
8482
def get_cluster_sizes(
8583
mapping: pl.DataFrame,
8684
cluster_id_col: str,
87-
) -> Dict[str, int]:
85+
) -> Dict[int, int]:
8886
"""
8987
Returns a dictionary where keys are cluster IDs and values are the number
9088
of chains/interfaces in the cluster.
@@ -186,26 +184,18 @@ def __init__(
186184
if not isinstance(chain_mapping_paths, list):
187185
chain_mapping_paths = [chain_mapping_paths]
188186

189-
chain_mapping = []
190-
for path in chain_mapping_paths:
191-
mapping = pl.read_csv(path)
192-
mapping = mapping.with_columns(
193-
(pl.col("molecule_id") + "-" + pl.col("cluster_id").cast(pl.String)).alias(
194-
"cluster_id"
195-
)
187+
chain_mapping = [pl.read_csv(path) for path in chain_mapping_paths]
188+
# Increment chain cluster IDs to avoid overlap
189+
chain_cluster_nums = [
190+
mapping.get_column("cluster_id").max() for mapping in chain_mapping
191+
]
192+
for i in range(1, len(chain_mapping)):
193+
chain_mapping[i] = chain_mapping[i].with_columns(
194+
(pl.col("cluster_id") + sum(chain_cluster_nums[:i])).alias("cluster_id")
196195
)
197-
chain_mapping.append(mapping)
196+
198197
chain_mapping = pl.concat(chain_mapping)
199198
interface_mapping = pl.read_csv(interface_mapping_path)
200-
interface_mapping = interface_mapping.with_columns(
201-
(
202-
pl.col("interface_molecule_id_1")
203-
+ "-"
204-
+ pl.col("interface_molecule_id_2")
205-
+ "-"
206-
+ pl.col("interface_cluster_id").cast(pl.String)
207-
).alias("interface_cluster_id")
208-
)
209199

210200
# Filter out unwanted PDB IDs
211201
if len(pdb_ids_to_skip) > 0:
@@ -228,6 +218,20 @@ def __init__(
228218
compute_interface_weights(interface_mapping, self.alphas, self.betas["interface"]),
229219
)
230220

221+
# Add additional information to the cluster IDs
222+
chain_mapping = chain_mapping.with_columns(
223+
(pl.col("molecule_id") + "-" + pl.col("cluster_id").cast(pl.String)).alias("cluster_id")
224+
)
225+
interface_mapping = interface_mapping.with_columns(
226+
(
227+
pl.col("interface_molecule_id_1")
228+
+ "-"
229+
+ pl.col("interface_molecule_id_2")
230+
+ "-"
231+
+ pl.col("interface_cluster_id").cast(pl.String)
232+
).alias("interface_cluster_id")
233+
)
234+
231235
# Concatenate chain and interface mappings
232236
chain_mapping = chain_mapping.with_columns(
233237
[

scripts/filter_pdb_mmcifs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from operator import itemgetter
3939
from typing import Dict, List, Literal, Set, Tuple
4040
from datetime import datetime
41-
from dateutil import parser as date_parser
4241

4342
import numpy as np
4443
import timeout_decorator
@@ -140,7 +139,7 @@ def filter_pdb_release_date(
140139
"release_date" in mmcif_object.header
141140
and exists(mmcif_object.header["release_date"])
142141
and min_cutoff_date
143-
<= date_parser.parse(mmcif_object.header["release_date"])
142+
<= datetime.strptime(mmcif_object.header["release_date"], "%Y-%m-%d")
144143
<= max_cutoff_date
145144
)
146145

0 commit comments

Comments
 (0)