1313import pickle
1414from abc import ABC
1515from collections import OrderedDict
16- from typing import Any , Dict , Generator , List , Literal , Optional , Tuple , Union
16+ from typing import (
17+ TYPE_CHECKING ,
18+ Any ,
19+ Dict ,
20+ Generator ,
21+ List ,
22+ Literal ,
23+ Optional ,
24+ Tuple ,
25+ Union ,
26+ )
1727
18- import fastobo
19- import networkx as nx
2028import pandas as pd
21- import requests
2229import torch
2330
2431from chebai .preprocessing import reader as dr
2532from chebai .preprocessing .datasets .base import XYBaseDataModule , _DynamicDataset
2633
34+ if TYPE_CHECKING :
35+ import fastobo
36+ import networkx as nx
37+
2738# exclude some entities from the dataset because the violate disjointness axioms
2839CHEBI_BLACKLIST = [
2940 194026 ,
@@ -214,6 +225,8 @@ def _load_chebi(self, version: int) -> str:
214225 Returns:
215226 str: The file path of the loaded ChEBI ontology.
216227 """
228+ import requests
229+
217230 chebi_name = self .raw_file_names_dict ["chebi" ]
218231 chebi_path = os .path .join (self .raw_dir , chebi_name )
219232 if not os .path .isfile (chebi_path ):
@@ -225,7 +238,7 @@ def _load_chebi(self, version: int) -> str:
225238 open (chebi_path , "wb" ).write (r .content )
226239 return chebi_path
227240
228- def _extract_class_hierarchy (self , data_path : str ) -> nx .DiGraph :
241+ def _extract_class_hierarchy (self , data_path : str ) -> " nx.DiGraph" :
229242 """
230243 Extracts the class hierarchy from the ChEBI ontology.
231244 Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -237,6 +250,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
237250 Returns:
238251 nx.DiGraph: The class hierarchy.
239252 """
253+ import fastobo
254+ import networkx as nx
255+
240256 with open (data_path , encoding = "utf-8" ) as chebi :
241257 chebi = "\n " .join (line for line in chebi if not line .startswith ("xref:" ))
242258
@@ -266,7 +282,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
266282 print ("Compute transitive closure" )
267283 return nx .transitive_closure_dag (g )
268284
269- def _graph_to_raw_dataset (self , g : nx .DiGraph ) -> pd .DataFrame :
285+ def _graph_to_raw_dataset (self , g : " nx.DiGraph" ) -> pd .DataFrame :
270286 """
271287 Converts the graph to a raw dataset.
272288 Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -278,6 +294,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
278294 Returns:
279295 pd.DataFrame: The raw dataset created from the graph.
280296 """
297+ import networkx as nx
298+
281299 smiles = nx .get_node_attributes (g , "smiles" )
282300 names = nx .get_node_attributes (g , "name" )
283301
@@ -590,7 +608,7 @@ def _name(self) -> str:
590608 """
591609 return f"ChEBI{ self .THRESHOLD } "
592610
593- def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> List :
611+ def select_classes (self , g : " nx.DiGraph" , * args , ** kwargs ) -> List :
594612 """
595613 Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.
596614
@@ -615,6 +633,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
615633 - The `THRESHOLD` attribute should be defined in the subclass of this class.
616634 - Nodes without a 'smiles' attribute are ignored in the successor count.
617635 """
636+ import networkx as nx
637+
618638 smiles = nx .get_node_attributes (g , "smiles" )
619639 nodes = list (
620640 sorted (
@@ -753,7 +773,7 @@ def processed_dir_main(self) -> str:
753773 "processed" ,
754774 )
755775
756- def _extract_class_hierarchy (self , chebi_path : str ) -> nx .DiGraph :
776+ def _extract_class_hierarchy (self , chebi_path : str ) -> " nx.DiGraph" :
757777 """
758778 Extracts a subset of ChEBI based on subclasses of the top class ID.
759779
@@ -791,8 +811,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
791811 )
792812 return g
793813
794- def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> List :
814+ def select_classes (self , g : " nx.DiGraph" , * args , ** kwargs ) -> List :
795815 """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
816+ import networkx as nx
817+
796818 smiles = nx .get_node_attributes (g , "smiles" )
797819 nodes = list (
798820 sorted (
@@ -868,7 +890,7 @@ def chebi_to_int(s: str) -> int:
868890 return int (s [s .index (":" ) + 1 :])
869891
870892
871- def term_callback (doc : fastobo .term .TermFrame ) -> Union [Dict , bool ]:
893+ def term_callback (doc : " fastobo.term.TermFrame" ) -> Union [Dict , bool ]:
872894 """
873895 Extracts information from a ChEBI term document.
874896 This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
@@ -885,6 +907,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
885907 - "name": The name of the ChEBI term.
886908 - "smiles": The SMILES string associated with the ChEBI term, if available.
887909 """
910+ import fastobo
911+
888912 parts = set ()
889913 parents = []
890914 name = None
0 commit comments