Skip to content

Commit 1470e93

Browse files
committed
Merge branch 'protein_prediction' into additional_unit_tests
2 parents e3c4b6e + 6511086 commit 1470e93

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

chebai/preprocessing/datasets/go_uniprot.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
__all__ = ["GOUniProtOver250", "GOUniProtOver50"]
1212

1313
import gzip
14+
import itertools
1415
import os
1516
import shutil
1617
from abc import ABC, abstractmethod
@@ -340,7 +341,18 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
340341
print(f"Processing graph")
341342

342343
data_df = self._get_swiss_to_go_mapping()
343-
344+
# add ancestors to go ids
345+
data_df["go_ids"] = data_df["go_ids"].apply(
346+
lambda go_ids: list(
347+
itertools.chain.from_iterable(
348+
[
349+
[go_id] + list(g.predecessors(go_id))
350+
for go_id in go_ids
351+
if go_id in g.nodes
352+
]
353+
)
354+
)
355+
)
344356
# Initialize the GO term labels/columns to False
345357
selected_classes = self.select_classes(g, data_df=data_df)
346358
new_label_columns = pd.DataFrame(
@@ -642,20 +654,8 @@ def select_classes(
642654
# https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L59-L77
643655
go_term_annot: Dict[int, int] = {}
644656
for idx, row in data_df.iterrows():
645-
# Set will contain go terms associated with the protein, along with all the ancestors of those
646-
# associated go terms
647-
associated_go_ids_with_ancestors = set()
648-
649-
# Collect all ancestors of the GO terms associated with this protein
650-
for go_id in row["go_ids"]:
651-
if go_id in g.nodes:
652-
associated_go_ids_with_ancestors.add(go_id)
653-
associated_go_ids_with_ancestors.update(
654-
g.predecessors(go_id)
655-
) # Add all predecessors (ancestors) of go_id
656-
657657
# Count the annotations for each go_id **`per protein`**
658-
for go_id in associated_go_ids_with_ancestors:
658+
for go_id in row["go_ids"]:
659659
if go_id not in go_term_annot:
660660
go_term_annot[go_id] = 0
661661
go_term_annot[go_id] += 1

0 commit comments

Comments
 (0)