diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 1df144d9..9fa1c1c7 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -301,9 +301,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: data = pd.DataFrame(data) data = data[~data["SMILES"].isnull()] data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]] - # This filters the DataFrame to include only the rows where at least one value in the row from 4th column - # onwards is True/non-zero. - data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data # ------------------------------ Phase: Setup data ----------------------------------- @@ -702,18 +700,24 @@ class ChEBIOverXPartial(ChEBIOverX): top_class_id (int): The ID of the top class from which to extract subclasses. """ - def __init__(self, top_class_id: int, **kwargs): + def __init__(self, top_class_id: int, external_data_ratio: float, **kwargs): """ Initializes the ChEBIOverXPartial dataset. Args: top_class_id (int): The ID of the top class from which to extract subclasses. **kwargs: Additional keyword arguments passed to the superclass initializer. + external_data_ratio (float): How much external data (i.e., samples where top_class_id + is not a positive label) to include in the dataset. 0 means no external data, 1 means + the maximum amount (i.e., the complete ChEBI dataset). 
""" if "top_class_id" not in kwargs: kwargs["top_class_id"] = top_class_id + if "external_data_ratio" not in kwargs: + kwargs["external_data_ratio"] = external_data_ratio self.top_class_id: int = top_class_id + self.external_data_ratio: float = external_data_ratio super().__init__(**kwargs) @property @@ -727,7 +731,7 @@ def processed_dir_main(self) -> str: return os.path.join( self.base_dir, self._name, - f"partial_{self.top_class_id}", + f"partial_{self.top_class_id}_ext_ratio_{self.external_data_ratio:.2f}", "processed", ) @@ -746,9 +750,53 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: descendants of the top class ID. """ g = super()._extract_class_hierarchy(chebi_path) - g = g.subgraph(list(g.successors(self.top_class_id)) + [self.top_class_id]) + top_class_successors = list(g.successors(self.top_class_id)) + [ + self.top_class_id + ] + external_nodes = list(set(n for n in g.nodes if n not in top_class_successors)) + if 0 < self.external_data_ratio < 1: + n_external_nodes = int( + len(top_class_successors) + * self.external_data_ratio + / (1 - self.external_data_ratio) + ) + print( + f"Extracting {n_external_nodes} external nodes from the ChEBI dataset (ratio: {self.external_data_ratio:.2f})" + ) + external_nodes = external_nodes[: int(n_external_nodes)] + elif self.external_data_ratio == 0: + external_nodes = [] + + g = g.subgraph(top_class_successors + external_nodes) + print( + f"Subgraph contains {len(g.nodes)} nodes, of which {len(top_class_successors)} are subclasses of the top class ID {self.top_class_id}." 
+ ) return g + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself).""" + smiles = nx.get_node_attributes(g, "smiles") + nodes = list( + sorted( + { + node + for node in g.nodes + if sum( + 1 if smiles[s] is not None else 0 for s in g.successors(node) + ) + >= self.THRESHOLD + and ( + self.top_class_id in g.predecessors(node) + or node == self.top_class_id + ) + } + ) + ) + filename = "classes.txt" + with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: + fout.writelines(str(node) + "\n" for node in nodes) + return nodes + class ChEBIOver50Partial(ChEBIOverXPartial, ChEBIOver50): """ @@ -842,7 +890,7 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]: atom_index = ( - "\*", + r"\*", "H", "He", "Li", @@ -1473,3 +1521,15 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]: ] JCI_500_COLUMNS_INT = [int(n.split(":")[-1]) for n in JCI_500_COLUMNS] + +if __name__ == "__main__": + data_module_05 = ChEBIOver50Partial( + chebi_version=241, + splits_file_path=os.path.join( + "data", "chebi_v241", "ChEBI50", "splits_80_10_10.csv" + ), + top_class_id=22712, + external_data_ratio=0.5, + ) + data_module_05.prepare_data() + data_module_05.setup() diff --git a/tests/unit/dataset_classes/testChebiOverXPartial.py b/tests/unit/dataset_classes/testChebiOverXPartial.py index 76584ebf..0e263335 100644 --- a/tests/unit/dataset_classes/testChebiOverXPartial.py +++ b/tests/unit/dataset_classes/testChebiOverXPartial.py @@ -16,7 +16,9 @@ def setUpClass(cls, mock_makedirs) -> None: """ Set up the ChEBIOverXPartial instance with a mock processed directory path and a test graph. 
""" - cls.chebi_extractor = ChEBIOverXPartial(top_class_id=11111, chebi_version=231) + cls.chebi_extractor = ChEBIOverXPartial( + top_class_id=11111, external_data_ratio=0, chebi_version=231 + ) cls.test_graph = ChebiMockOntology.get_transitively_closed_graph() @patch(