1515from abc import ABC
1616from collections import OrderedDict
1717from itertools import cycle , permutations , product
18- from typing import Any , Generator , Optional , Union
18+ from typing import TYPE_CHECKING , Any , Dict , Generator , List , Optional , Union
1919
20- import fastobo
21- import networkx as nx
2220import pandas as pd
23- import requests
2421import torch
2522from rdkit import Chem
2623
2724from chebai .preprocessing import reader as dr
2825from chebai .preprocessing .datasets .base import XYBaseDataModule , _DynamicDataset
2926
27+ if TYPE_CHECKING :
28+ import fastobo
29+ import networkx as nx
30+
3031# exclude some entities from the dataset because they violate disjointness axioms
3132CHEBI_BLACKLIST = [
3233 194026 ,
@@ -236,6 +237,8 @@ def _load_chebi(self, version: int) -> str:
236237 Returns:
237238 str: The file path of the loaded ChEBI ontology.
238239 """
240+ import requests
241+
239242 chebi_name = self .raw_file_names_dict ["chebi" ]
240243 chebi_path = os .path .join (self .raw_dir , chebi_name )
241244 if not os .path .isfile (chebi_path ):
@@ -247,7 +250,7 @@ def _load_chebi(self, version: int) -> str:
247250 open (chebi_path , "wb" ).write (r .content )
248251 return chebi_path
249252
250- def _extract_class_hierarchy (self , data_path : str ) -> nx .DiGraph :
253+ def _extract_class_hierarchy (self , data_path : str ) -> " nx.DiGraph" :
251254 """
252255 Extracts the class hierarchy from the ChEBI ontology.
253256 Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -259,6 +262,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
259262 Returns:
260263 nx.DiGraph: The class hierarchy.
261264 """
265+ import fastobo
266+ import networkx as nx
267+
262268 with open (data_path , encoding = "utf-8" ) as chebi :
263269 chebi = "\n " .join (line for line in chebi if not line .startswith ("xref:" ))
264270
@@ -286,7 +292,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
286292 print ("Compute transitive closure" )
287293 return nx .transitive_closure_dag (g )
288294
289- def _graph_to_raw_dataset (self , g : nx .DiGraph ) -> pd .DataFrame :
295+ def _graph_to_raw_dataset (self , g : " nx.DiGraph" ) -> pd .DataFrame :
290296 """
291297 Converts the graph to a raw dataset.
292298 Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -298,6 +304,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
298304 Returns:
299305 pd.DataFrame: The raw dataset created from the graph.
300306 """
307+ import networkx as nx
308+
301309 smiles = nx .get_node_attributes (g , "smiles" )
302310 names = nx .get_node_attributes (g , "name" )
303311
@@ -696,7 +704,7 @@ def _name(self) -> str:
696704 """
697705 return f"ChEBI{ self .THRESHOLD } "
698706
699- def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> list :
707+ def select_classes (self , g : " nx.DiGraph" , * args , ** kwargs ) -> List :
700708 """
701709 Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.
702710
@@ -721,6 +729,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
721729 - The `THRESHOLD` attribute should be defined in the subclass of this class.
722730 - Nodes without a 'smiles' attribute are ignored in the successor count.
723731 """
732+ import networkx as nx
733+
724734 smiles = nx .get_node_attributes (g , "smiles" )
725735 nodes = list (
726736 sorted (
@@ -859,7 +869,7 @@ def processed_dir_main(self) -> str:
859869 "processed" ,
860870 )
861871
862- def _extract_class_hierarchy (self , chebi_path : str ) -> nx .DiGraph :
872+ def _extract_class_hierarchy (self , chebi_path : str ) -> " nx.DiGraph" :
863873 """
864874 Extracts a subset of ChEBI based on subclasses of the top class ID.
865875
@@ -897,8 +907,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
897907 )
898908 return g
899909
900- def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> list :
910+ def select_classes (self , g : " nx.DiGraph" , * args , ** kwargs ) -> List :
901911 """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
912+ import networkx as nx
913+
902914 smiles = nx .get_node_attributes (g , "smiles" )
903915 nodes = list (
904916 sorted (
@@ -958,7 +970,7 @@ def chebi_to_int(s: str) -> int:
958970 return int (s [s .index (":" ) + 1 :])
959971
960972
961- def term_callback (doc : fastobo .term .TermFrame ) -> Union [dict , bool ]:
973+ def term_callback (doc : " fastobo.term.TermFrame" ) -> Union [Dict , bool ]:
962974 """
963975 Extracts information from a ChEBI term document.
964976 This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
@@ -975,6 +987,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[dict, bool]:
975987 - "name": The name of the ChEBI term.
976988 - "smiles": The SMILES string associated with the ChEBI term, if available.
977989 """
990+ import fastobo
991+
978992 parts = set ()
979993 parents = []
980994 name = None
0 commit comments