Skip to content

Commit f50b2dc

Browse files
authored
Update weighted_pdb_sampler.py (#98)
1 parent 499fb2b commit f50b2dc

File tree

1 file changed

+4
-18
lines changed

1 file changed

+4
-18
lines changed

alphafold3_pytorch/data/weighted_pdb_sampler.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,7 @@ def __init__(
186186

187187
chain_mapping = [pl.read_csv(path) for path in chain_mapping_paths]
188188
# Increment chain cluster IDs to avoid overlap
189-
chain_cluster_nums = [
190-
mapping.get_column("cluster_id").max() for mapping in chain_mapping
191-
]
189+
chain_cluster_nums = [mapping.get_column("cluster_id").max() for mapping in chain_mapping]
192190
for i in range(1, len(chain_mapping)):
193191
chain_mapping[i] = chain_mapping[i].with_columns(
194192
(pl.col("cluster_id") + sum(chain_cluster_nums[:i])).alias("cluster_id")
@@ -218,20 +216,6 @@ def __init__(
218216
compute_interface_weights(interface_mapping, self.alphas, self.betas["interface"]),
219217
)
220218

221-
# Add additional information to the cluster IDs
222-
chain_mapping = chain_mapping.with_columns(
223-
(pl.col("molecule_id") + "-" + pl.col("cluster_id").cast(pl.String)).alias("cluster_id")
224-
)
225-
interface_mapping = interface_mapping.with_columns(
226-
(
227-
pl.col("interface_molecule_id_1")
228-
+ "-"
229-
+ pl.col("interface_molecule_id_2")
230-
+ "-"
231-
+ pl.col("interface_cluster_id").cast(pl.String)
232-
).alias("interface_cluster_id")
233-
)
234-
235219
# Concatenate chain and interface mappings
236220
chain_mapping = chain_mapping.with_columns(
237221
[
@@ -247,7 +231,9 @@ def __init__(
247231
[
248232
pl.col("interface_chain_id_1").alias("chain_id_1"),
249233
pl.col("interface_chain_id_2").alias("chain_id_2"),
250-
pl.col("interface_cluster_id").alias("cluster_id"),
234+
(
235+
pl.col("interface_cluster_id") + chain_mapping.get_column("cluster_id").max()
236+
).alias("cluster_id"),
251237
]
252238
)
253239
interface_mapping = interface_mapping.select(

0 commit comments

Comments
 (0)