Skip to content

Commit e9916cc

Browse files
committed
fix: allow custom scoring
1 parent c38b662 commit e9916cc

File tree

1 file changed

+43
-28
lines changed

1 file changed

+43
-28
lines changed

src/plinder/data/utils/annotations/get_similarity_scores.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def run_alignment(
199199
str(x)
200200
.replace("_xyz-enrich.cif.gz", "")
201201
.replace("_xyz-enrich.cif", "")
202+
.replace(".cif.gz", "")
203+
.replace(".cif", "")
202204
.replace("pdb_0000", "")[:4]
203205
for x in table["query"]
204206
]
@@ -212,6 +214,8 @@ def run_alignment(
212214
str(x)
213215
.replace("_xyz-enrich.cif.gz", "")
214216
.replace("_xyz-enrich.cif", "")
217+
.replace(".cif.gz", "")
218+
.replace(".cif", "")
215219
.replace("pdb_0000", "")[:4]
216220
for x in table["target"]
217221
]
@@ -422,6 +426,29 @@ def get_score_df(
422426
if not pdb_id_file.exists():
423427
LOG.info(f"get_score_df: pdb_id_file={pdb_id_file} does not exist")
424428
continue
429+
430+
# self.entries = {}
431+
entries_to_load = {pdb_id}
432+
if search_db != "pred" and pdb_id_file.exists():
433+
entries_to_load |= set(
434+
pd.read_parquet(pdb_id_file, columns=["target_pdb_id"])[
435+
"target_pdb_id"
436+
]
437+
)
438+
entries_to_load = entries_to_load.difference(self.entries.keys())
439+
LOG.info(f"entries_to_load pdb_id={pdb_id} {len(entries_to_load)}")
440+
LOG.info(
441+
f"loading {len(entries_to_load)} (additional) entries for {pdb_id}"
442+
)
443+
self.entries.update(
444+
load_entries_from_zips(
445+
data_dir=data_dir,
446+
pdb_ids=entries_to_load,
447+
load_for_scoring=True,
448+
max_protein_chains=20,
449+
max_ligand_chains=20,
450+
)
451+
)
425452
pdb_file = (
426453
self.db_dir
427454
/ f"{search_db}_{aln_type}"
@@ -430,30 +457,9 @@ def get_score_df(
430457
)
431458
pdb_file.parent.mkdir(exist_ok=True, parents=True)
432459
if overwrite or not pdb_file.exists():
433-
self.entries = {}
434-
entries_to_load = {pdb_id}
435-
if search_db != "pred" and pdb_id_file.exists():
436-
entries_to_load |= set(
437-
pd.read_parquet(pdb_id_file, columns=["target_pdb_id"])[
438-
"target_pdb_id"
439-
]
440-
)
441-
entries_to_load = entries_to_load.difference(self.entries.keys())
442-
LOG.info(f"entries_to_load pdb_id={pdb_id} {len(entries_to_load)}")
443-
LOG.info(
444-
f"loading {len(entries_to_load)} (additional) entries for {pdb_id}"
445-
)
446-
self.entries.update(
447-
load_entries_from_zips(
448-
data_dir=data_dir,
449-
pdb_ids=entries_to_load,
450-
load_for_scoring=True,
451-
)
452-
)
453-
454460
try:
455461
LOG.info(
456-
f"mapping aligmnet df for {pdb_id} to {search_db} for {aln_type}"
462+
f"mapping aligment df for {pdb_id} to {search_db} for {aln_type}"
457463
)
458464
self.map_alignment_df(pdb_id_file, aln_type, search_db).to_parquet(
459465
pdb_file, index=True
@@ -533,7 +539,12 @@ def map_alignment_df(
533539
df = pd.read_parquet(df_file)
534540
if aln_type == "foldseek":
535541
df["query"] = df["query"].replace(
536-
{"_xyz-enrich.cif.gz": "", "_xyz-enrich.cif": "", "pdb_0000": ""},
542+
{
543+
"_xyz-enrich.cif.gz": "",
544+
"_xyz-enrich.cif": "",
545+
"pdb_0000": "",
546+
".cif.gz": "",
547+
},
537548
regex=True,
538549
)
539550
if search_db == "pred":
@@ -542,7 +553,12 @@ def map_alignment_df(
542553
)
543554
else:
544555
df["target"] = df["target"].replace(
545-
{"_xyz-enrich.cif.gz": "", "_xyz-enrich.cif": "", "pdb_0000": ""},
556+
{
557+
"_xyz-enrich.cif.gz": "",
558+
"_xyz-enrich.cif": "",
559+
"pdb_0000": "",
560+
".cif.gz": "",
561+
},
546562
regex=True,
547563
)
548564
df["query_chain_mapped"] = (
@@ -754,9 +770,9 @@ def get_pocket_pli_scores(
754770
]:
755771
pocket_scores: _SimilarityScoreDictType = defaultdict(float)
756772
pli_scores: _SimilarityScoreDictType = defaultdict(float)
757-
pocket_length = query_system.num_pocket_residues
758-
pli_length = query_system.num_interactions
759-
pli_unique_length = query_system.num_unique_interactions
773+
pocket_length = query_system.proper_num_pocket_residues
774+
pli_length = query_system.proper_num_interactions
775+
pli_unique_length = query_system.proper_num_unique_interactions
760776
for q_instance_chain, t_instance_chain in alns:
761777
aln = alns[(q_instance_chain, t_instance_chain)]
762778
q_pocket = query_system.pocket_residues.get(q_instance_chain, {})
@@ -856,7 +872,6 @@ def get_scores_holo(
856872
q_chain
857873
].index.get_level_values("target_chain_mapped")
858874
)
859-
860875
for target_system_id in self.entries[target_entry].systems:
861876
target_system = self.entries[target_entry].systems[target_system_id]
862877
if target_system.system_type != "holo":

0 commit comments

Comments
 (0)