Skip to content

Commit 458c4fd

Browse files
authored
Add WeightedSamplerPDB (#93)
* Add WeightedSamplerPDB Introduced a sampler that implements the algorithm described in Section 2.5.1 of the AlphaFold 3 supplement. * Convert Pandas code to Polars * Added test for weighted sampling code, changed name to `weighted_pdb_dataset.py` * Add example `*_cluster_mapping.csv` files for testing
1 parent f29d51f commit 458c4fd

11 files changed

+540
-68
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,5 @@ cython_debug/
162162
#.idea/
163163

164164
# alphafold3-pytorch
165-
/data/
165+
/data/*
166+
!/data/test/
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
from typing import Iterator, Union
2+
from torch.utils.data import Sampler
3+
import polars as pl
4+
import numpy as np
5+
6+
7+
def get_chain_count(molecule_type) -> tuple[int, int, int]:
8+
"""
9+
Returns the number of protein, nucleic acid, and ligand chains in a
10+
molecule based on its type.
11+
12+
Example:
13+
n_prot, n_nuc, n_ligand = get_chain_count("protein")
14+
"""
15+
match molecule_type:
16+
case "protein":
17+
return 1, 0, 0
18+
case "nucleic_acid":
19+
return 0, 1, 0
20+
case "ligand":
21+
return 0, 0, 1
22+
case "peptide":
23+
return 1, 0, 0
24+
case _:
25+
raise ValueError(f"Unknown molecule type: {molecule_type}")
26+
27+
28+
def calculate_weight(alphas, beta, n_prot, n_nuc, n_ligand, cluster_size) -> float:
29+
"""
30+
Calculates the weight of a chain or an interface according to the formula
31+
provided in Section 2.5.1 of the AlphaFold3 supplementary information.
32+
"""
33+
return (beta / cluster_size) * (
34+
alphas["prot"] * n_prot + alphas["nuc"] * n_nuc + alphas["ligand"] * n_ligand
35+
)
36+
37+
38+
def get_chain_weight(molecule_type, cluster_size, alphas, beta) -> float:
39+
n_prot, n_nuc, n_ligand = get_chain_count(molecule_type)
40+
return calculate_weight(alphas, beta, n_prot, n_nuc, n_ligand, cluster_size)
41+
42+
43+
def get_interface_weight(
44+
molecule_type_1, molecule_type_2, cluster_size, alphas, beta
45+
) -> float:
46+
p1, n1, l1 = get_chain_count(molecule_type_1)
47+
p2, n2, l2 = get_chain_count(molecule_type_2)
48+
49+
n_prot = p1 + p2
50+
n_nuc = n1 + n2
51+
n_ligand = l1 + l2
52+
53+
return calculate_weight(alphas, beta, n_prot, n_nuc, n_ligand, cluster_size)
54+
55+
56+
def get_cluster_sizes(mapping, cluster_id_col) -> dict[int, int]:
57+
"""
58+
Returns a dictionary where keys are cluster IDs and values are the number
59+
of chains/interfaces in the cluster.
60+
"""
61+
cluster_sizes = mapping.group_by(cluster_id_col).agg(pl.len()).sort(cluster_id_col)
62+
return {row[0]: row[1] for row in cluster_sizes.iter_rows()}
63+
64+
65+
def compute_chain_weights(chains: pl.DataFrame, alphas, beta) -> pl.Series:
66+
molecule_idx = chains.get_column_index("molecule_id")
67+
cluster_idx = chains.get_column_index("cluster_id")
68+
cluster_sizes = get_cluster_sizes(chains, "cluster_id")
69+
70+
return (
71+
chains.map_rows(
72+
lambda row: get_chain_weight(
73+
row[molecule_idx].split("-")[0],
74+
cluster_sizes[row[cluster_idx]],
75+
alphas,
76+
beta,
77+
),
78+
return_dtype=pl.Float32,
79+
)
80+
.to_series(0)
81+
.rename("weight")
82+
)
83+
84+
85+
def compute_interface_weights(interfaces: pl.DataFrame, alphas, beta) -> pl.Series:
86+
molecule_idx_1 = interfaces.get_column_index("interface_molecule_id_1")
87+
molecule_idx_2 = interfaces.get_column_index("interface_molecule_id_2")
88+
cluster_idx = interfaces.get_column_index("interface_cluster_id")
89+
cluster_sizes = get_cluster_sizes(interfaces, "interface_cluster_id")
90+
91+
return (
92+
interfaces.map_rows(
93+
lambda row: get_interface_weight(
94+
row[molecule_idx_1].split("-")[0],
95+
row[molecule_idx_2].split("-")[0],
96+
cluster_sizes[row[cluster_idx]],
97+
alphas,
98+
beta,
99+
),
100+
return_dtype=pl.Float32,
101+
)
102+
.to_series(0)
103+
.rename("weight")
104+
)
105+
106+
107+
class WeightedSamplerPDB(Sampler[list[str]]):
108+
def __init__(
109+
self,
110+
chain_mapping_paths: Union[str, list[str]],
111+
interface_mapping_path: str,
112+
batch_size: int,
113+
beta_chain: float = 0.5,
114+
beta_interface: float = 1.0,
115+
alpha_prot: float = 3.0,
116+
alpha_nuc: float = 3.0,
117+
alpha_ligand: float = 1.0,
118+
pdb_ids_to_skip: list[str] = [],
119+
):
120+
"""
121+
Initializes a dataset for weighted sampling of PDB IDs.
122+
123+
Args
124+
-------
125+
chain_mapping_paths (Union[str, list[str]])
126+
Path to the CSV file containing chain cluster
127+
mappings. If multiple paths are provided, they will be
128+
concatenated.
129+
interface_mapping_path (str)
130+
Path to the CSV file containing interface
131+
cluster mappings.
132+
batch_size (int)
133+
Number of PDB IDs to sample in each batch.
134+
beta_chain (float)
135+
Weighting factor for chain clusters.
136+
beta_interface (float)
137+
Weighting factor for interface clusters.
138+
alpha_prot (float)
139+
Weighting factor for protein chains.
140+
alpha_nuc (float)
141+
Weighting factor for nucleic acid chains.
142+
alpha_ligand (float)
143+
Weighting factor for ligand chains.
144+
pdb_ids_to_skip (list[str])
145+
List of PDB IDs to skip during sampling.
146+
Allow extra data filtering to ensure we avoid training
147+
on anomolous complexes that passed through all filtering
148+
and clustering steps.
149+
150+
Example
151+
-------
152+
```
153+
sampler = WeightedPDBSampler(...)
154+
for batch in sampler:
155+
print(batch)
156+
```
157+
"""
158+
159+
# Load chain and interface mappings
160+
if not isinstance(chain_mapping_paths, list):
161+
chain_mapping_paths = [chain_mapping_paths]
162+
163+
chain_mapping = [pl.read_csv(path) for path in chain_mapping_paths]
164+
chain_mapping = pl.concat(chain_mapping)
165+
interface_mapping = pl.read_csv(interface_mapping_path)
166+
167+
# Filter out unwanted PDB IDs
168+
if len(pdb_ids_to_skip) > 0:
169+
chain_mapping = chain_mapping.filter(
170+
pl.col("pdb_id").is_in(pdb_ids_to_skip).not_()
171+
)
172+
interface_mapping = interface_mapping.filter(
173+
pl.col("pdb_id").is_in(pdb_ids_to_skip).not_()
174+
)
175+
176+
# Calculate weights for chains and interfaces
177+
self.alphas = {"prot": alpha_prot, "nuc": alpha_nuc, "ligand": alpha_ligand}
178+
self.betas = {"chain": beta_chain, "interface": beta_interface}
179+
self.batch_size = batch_size
180+
181+
chain_mapping.insert_column(
182+
len(chain_mapping.columns),
183+
compute_chain_weights(chain_mapping, self.alphas, self.betas["chain"]),
184+
)
185+
interface_mapping.insert_column(
186+
len(interface_mapping.columns),
187+
compute_interface_weights(
188+
interface_mapping, self.alphas, self.betas["interface"]
189+
),
190+
)
191+
192+
# Concatenate chain and interface mappings
193+
chain_mapping = chain_mapping.select(["pdb_id", "cluster_id", "weight"])
194+
195+
num_chain_clusters = chain_mapping.get_column("cluster_id").max() + 1
196+
interface_mapping = interface_mapping.with_columns(
197+
(pl.col("interface_cluster_id") + num_chain_clusters).alias("cluster_id")
198+
)
199+
interface_mapping = interface_mapping.select(["pdb_id", "cluster_id", "weight"])
200+
self.mappings = chain_mapping.extend(interface_mapping)
201+
202+
# Normalize weights
203+
self.weights = self.mappings.get_column("weight").to_numpy()
204+
self.weights = self.weights / self.weights.sum()
205+
206+
def __len__(self) -> int:
207+
return len(self.mappings) // self.batch_size
208+
209+
def __iter__(self) -> Iterator[list[str]]:
210+
while True:
211+
yield self.sample(self.batch_size)
212+
213+
def sample(self, batch_size: int) -> list[str]:
214+
indices = np.random.choice(len(self.mappings), size=batch_size, p=self.weights)
215+
return self.mappings.get_column("pdb_id").gather(indices).to_list()
216+
217+
def cluster_based_sample(self, batch_size: int) -> list[str]:
218+
"""
219+
Samples PDB IDs based on cluster IDs. For each batch, a number of cluster IDs
220+
are selected randomly, and a PDB ID is sampled from each cluster based on the
221+
weights of the chains/interfaces in the cluster.
222+
223+
Warning! Significantly slower than the regular `sample` method.
224+
"""
225+
cluster_ids = self.mappings.get_column("cluster_id").unique().sample(batch_size)
226+
227+
pdb_ids = []
228+
for cluster_id in cluster_ids:
229+
cluster = self.mappings.filter(pl.col("cluster_id") == cluster_id)
230+
if len(cluster) == 1:
231+
pdb_ids.append(cluster.item(0, "pdb_id"))
232+
continue
233+
cluster_weights = cluster.get_column("weight").to_numpy()
234+
cluster_weights = cluster_weights / cluster_weights.sum()
235+
idx = np.random.choice(len(cluster), p=cluster_weights)
236+
pdb_ids.append(cluster.item(idx, "pdb_id"))
237+
238+
return pdb_ids
239+
240+
241+
if __name__ == "__main__":
242+
interface_mapping_path = (
243+
"data/pdb_data/data_caches/clusterings/interface_cluster_mapping.csv"
244+
)
245+
chain_mapping_paths = [
246+
"data/pdb_data/data_caches/clusterings/ligand_chain_cluster_mapping.csv",
247+
"data/pdb_data/data_caches/clusterings/nucleic_acid_chain_cluster_mapping.csv",
248+
"data/pdb_data/data_caches/clusterings/peptide_chain_cluster_mapping.csv",
249+
"data/pdb_data/data_caches/clusterings/protein_chain_cluster_mapping.csv",
250+
]
251+
252+
dataset = WeightedSamplerPDB(
253+
chain_mapping_paths=chain_mapping_paths,
254+
interface_mapping_path=interface_mapping_path,
255+
batch_size=64,
256+
)
257+
258+
print(dataset.sample(64))
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
pdb_id,interface_chain_id_1,interface_chain_id_2,interface_molecule_id_1,interface_molecule_id_2,interface_chain_cluster_id_1,interface_chain_cluster_id_2,interface_cluster_id
2+
121d,A,B,ligand-NT,nucleic_acid,0,4147,0
3+
121d,A,B,nucleic_acid,nucleic_acid,4147,4147,1
4+
420d,A,B,nucleic_acid,nucleic_acid,4148,4148,2
5+
113d,A,B,nucleic_acid,nucleic_acid,4149,4149,3
6+
320d,A,B,nucleic_acid,ligand-SPM,4150,4,4
7+
320d,A,B,nucleic_acid,nucleic_acid,4150,4150,5
8+
127d,A,B,nucleic_acid,ligand-HT,4151,5,6
9+
127d,A,B,nucleic_acid,nucleic_acid,4151,4151,7
10+
426d,A,B,nucleic_acid,nucleic_acid,4153,4153,8
11+
313d,A,B,ligand-NCO,nucleic_acid,9,4154,9
12+
313d,A,B,nucleic_acid,nucleic_acid,4154,4154,10
13+
326d,A,B,nucleic_acid,ligand-SPM,4150,4,4
14+
326d,A,B,nucleic_acid,nucleic_acid,4150,4150,5
15+
412d,A,B,ligand-MG,nucleic_acid,8,4156,11
16+
412d,A,B,nucleic_acid,nucleic_acid,4156,4156,12
17+
312d,A,B,nucleic_acid,ligand-NCO,4154,9,9
18+
312d,A,B,nucleic_acid,nucleic_acid,4154,4157,13
19+
11gs,A,B,protein,ligand-GSH,17328,11,14
20+
11gs,A,B,ligand-GSH,protein,11,17328,14
21+
11gs,A,B,protein,protein,17328,17328,15
22+
405d,A,B,nucleic_acid,nucleic_acid,4158,4158,16
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
pdb_id,chain_id,molecule_id,cluster_id
2+
121d,A,ligand-NT,0
3+
113l,A,ligand-CL,1
4+
113l,A,ligand-BME,2
5+
420d,A,ligand-NA,3
6+
320d,B,ligand-SPM,4
7+
127d,B,ligand-HT,5
8+
220l,A,ligand-CL,1
9+
220l,A,ligand-BME,2
10+
220l,A,ligand-BNZ,6
11+
426d,A,ligand-CA,7
12+
313d,A,ligand-MG,8
13+
313d,A,ligand-NCO,9
14+
220d,A,ligand-BA,10
15+
326d,B,ligand-SPM,4
16+
412d,A,ligand-MG,8
17+
412d,B,ligand-MG,8
18+
120l,A,ligand-CL,1
19+
120l,A,ligand-BME,2
20+
312d,B,ligand-MG,8
21+
312d,B,ligand-NCO,9
22+
11gs,A,ligand-GSH,11
23+
11gs,A,ligand-EAA,12
24+
11gs,B,ligand-GSH,11
25+
11gs,B,ligand-EAA,12
26+
226d,B,ligand-BIZ,13
27+
212l,A,ligand-HED,14
28+
306d,A,ligand-DMY,15
29+
306d,A,ligand-MG,8
30+
126l,A,ligand-CL,1
31+
126l,A,ligand-BME,2
32+
212d,A,ligand-NCO,9
33+
403d,A,ligand-HT1,16
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
pdb_id,chain_id,molecule_id,cluster_id
2+
2qnf,C,nucleic_acid,0
3+
2qnf,D,nucleic_acid,1
4+
2qnf,E,nucleic_acid,2
5+
2qnf,F,nucleic_acid,3
6+
2qnc,C,nucleic_acid,4
7+
2qnc,D,nucleic_acid,5
8+
2wiz,C,nucleic_acid,5
9+
2wiz,D,nucleic_acid,5
10+
2wj0,C,nucleic_acid,5
11+
2wj0,D,nucleic_acid,5
12+
2qnc,E,nucleic_acid,6
13+
2qnc,F,nucleic_acid,7
14+
6qn3,B,nucleic_acid,8
15+
6qn3,A,nucleic_acid,8
16+
3qoq,F,nucleic_acid,9
17+
3qoq,E,nucleic_acid,9
18+
4qoz,A,nucleic_acid,10
19+
4l8r,A,nucleic_acid,10
20+
4qoz,C,nucleic_acid,11
21+
1qp7,B,nucleic_acid,12
22+
1qp0,B,nucleic_acid,12
23+
1qph,A,nucleic_acid,13
24+
1qpi,A,nucleic_acid,14
25+
1qps,B,nucleic_acid,15
26+
1qps,C,nucleic_acid,16

0 commit comments

Comments
 (0)