Skip to content

Commit fe01f5a

Browse files
authored
Merge pull request #116 from ChEB-AI/feature/better-partial-data
allow negative samples (and regulate their amount for partial data)
2 parents ac8cf63 + a47a675 commit fe01f5a

File tree

2 files changed

+70
-8
lines changed

2 files changed

+70
-8
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
301301
data = pd.DataFrame(data)
302302
data = data[~data["SMILES"].isnull()]
303303
data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]]
304-
# This filters the DataFrame to include only the rows where at least one value in the row from 4th column
305-
# onwards is True/non-zero.
306-
data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
304+
307305
return data
308306

309307
# ------------------------------ Phase: Setup data -----------------------------------
@@ -702,18 +700,24 @@ class ChEBIOverXPartial(ChEBIOverX):
702700
top_class_id (int): The ID of the top class from which to extract subclasses.
703701
"""
704702

705-
def __init__(self, top_class_id: int, **kwargs):
703+
def __init__(self, top_class_id: int, external_data_ratio: float, **kwargs):
706704
"""
707705
Initializes the ChEBIOverXPartial dataset.
708706
709707
Args:
710708
top_class_id (int): The ID of the top class from which to extract subclasses.
711709
**kwargs: Additional keyword arguments passed to the superclass initializer.
710+
external_data_ratio (float): How much external data (i.e., samples where top_class_id
711+
is no positive label) to include in the dataset. 0 means no external data, 1 means
712+
the maximum amount (i.e., the complete ChEBI dataset).
712713
"""
713714
if "top_class_id" not in kwargs:
714715
kwargs["top_class_id"] = top_class_id
716+
if "external_data_ratio" not in kwargs:
717+
kwargs["external_data_ratio"] = external_data_ratio
715718

716719
self.top_class_id: int = top_class_id
720+
self.external_data_ratio: float = external_data_ratio
717721
super().__init__(**kwargs)
718722

719723
@property
@@ -727,7 +731,7 @@ def processed_dir_main(self) -> str:
727731
return os.path.join(
728732
self.base_dir,
729733
self._name,
730-
f"partial_{self.top_class_id}",
734+
f"partial_{self.top_class_id}_ext_ratio_{self.external_data_ratio:.2f}",
731735
"processed",
732736
)
733737

@@ -746,9 +750,53 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
746750
descendants of the top class ID.
747751
"""
748752
g = super()._extract_class_hierarchy(chebi_path)
749-
g = g.subgraph(list(g.successors(self.top_class_id)) + [self.top_class_id])
753+
top_class_successors = list(g.successors(self.top_class_id)) + [
754+
self.top_class_id
755+
]
756+
external_nodes = list(set(n for n in g.nodes if n not in top_class_successors))
757+
if 0 < self.external_data_ratio < 1:
758+
n_external_nodes = int(
759+
len(top_class_successors)
760+
* self.external_data_ratio
761+
/ (1 - self.external_data_ratio)
762+
)
763+
print(
764+
f"Extracting {n_external_nodes} external nodes from the ChEBI dataset (ratio: {self.external_data_ratio:.2f})"
765+
)
766+
external_nodes = external_nodes[: int(n_external_nodes)]
767+
elif self.external_data_ratio == 0:
768+
external_nodes = []
769+
770+
g = g.subgraph(top_class_successors + external_nodes)
771+
print(
772+
f"Subgraph contains {len(g.nodes)} nodes, of which {len(top_class_successors)} are subclasses of the top class ID {self.top_class_id}."
773+
)
750774
return g
751775

776+
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
777+
"""Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
778+
smiles = nx.get_node_attributes(g, "smiles")
779+
nodes = list(
780+
sorted(
781+
{
782+
node
783+
for node in g.nodes
784+
if sum(
785+
1 if smiles[s] is not None else 0 for s in g.successors(node)
786+
)
787+
>= self.THRESHOLD
788+
and (
789+
self.top_class_id in g.predecessors(node)
790+
or node == self.top_class_id
791+
)
792+
}
793+
)
794+
)
795+
filename = "classes.txt"
796+
with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
797+
fout.writelines(str(node) + "\n" for node in nodes)
798+
return nodes
799+
752800

753801
class ChEBIOver50Partial(ChEBIOverXPartial, ChEBIOver50):
754802
"""
@@ -842,7 +890,7 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
842890

843891

844892
atom_index = (
845-
"\*",
893+
r"\*",
846894
"H",
847895
"He",
848896
"Li",
@@ -1473,3 +1521,15 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
14731521
]
14741522

14751523
JCI_500_COLUMNS_INT = [int(n.split(":")[-1]) for n in JCI_500_COLUMNS]
1524+
1525+
if __name__ == "__main__":
1526+
data_module_05 = ChEBIOver50Partial(
1527+
chebi_version=241,
1528+
splits_file_path=os.path.join(
1529+
"data", "chebi_v241", "ChEBI50", "splits_80_10_10.csv"
1530+
),
1531+
top_class_id=22712,
1532+
external_data_ratio=0.5,
1533+
)
1534+
data_module_05.prepare_data()
1535+
data_module_05.setup()

tests/unit/dataset_classes/testChebiOverXPartial.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def setUpClass(cls, mock_makedirs) -> None:
1616
"""
1717
Set up the ChEBIOverXPartial instance with a mock processed directory path and a test graph.
1818
"""
19-
cls.chebi_extractor = ChEBIOverXPartial(top_class_id=11111, chebi_version=231)
19+
cls.chebi_extractor = ChEBIOverXPartial(
20+
top_class_id=11111, external_data_ratio=0, chebi_version=231
21+
)
2022
cls.test_graph = ChebiMockOntology.get_transitively_closed_graph()
2123

2224
@patch(

0 commit comments

Comments
 (0)