@@ -301,9 +301,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
301301 data = pd .DataFrame (data )
302302 data = data [~ data ["SMILES" ].isnull ()]
303303 data = data [[name not in CHEBI_BLACKLIST for name , _ in data .iterrows ()]]
304- # This filters the DataFrame to include only the rows where at least one value in the row from 4th column
305- # onwards is True/non-zero.
306- data = data [data .iloc [:, self ._LABELS_START_IDX :].any (axis = 1 )]
304+
307305 return data
308306
309307 # ------------------------------ Phase: Setup data -----------------------------------
@@ -702,18 +700,24 @@ class ChEBIOverXPartial(ChEBIOverX):
702700 top_class_id (int): The ID of the top class from which to extract subclasses.
703701 """
704702
def __init__(self, top_class_id: int, external_data_ratio: float, **kwargs):
    """
    Initializes the ChEBIOverXPartial dataset.

    Args:
        top_class_id (int): The ID of the top class from which to extract subclasses.
        external_data_ratio (float): How much external data (i.e., samples for which
            top_class_id is not a positive label) to include in the dataset. 0 means no
            external data, 1 means the maximum amount (i.e., the complete ChEBI dataset).
        **kwargs: Additional keyword arguments passed to the superclass initializer.

    Raises:
        ValueError: If external_data_ratio is not within the closed interval [0, 1].
    """
    # The ratio is documented as a fraction in [0, 1]; fail early on anything else
    # (including NaN) instead of silently carrying an out-of-range value around.
    if not 0 <= external_data_ratio <= 1:
        raise ValueError(
            f"external_data_ratio must be between 0 and 1, got {external_data_ratio!r}"
        )

    # Both values are bound to named parameters, so they can never already be present
    # in kwargs; add them unconditionally so the superclass initializer receives them.
    # NOTE(review): presumably the superclass records these (e.g. as hyperparameters)
    # -- confirm against the superclass __init__.
    kwargs["top_class_id"] = top_class_id
    kwargs["external_data_ratio"] = external_data_ratio

    self.top_class_id: int = top_class_id
    self.external_data_ratio: float = external_data_ratio
    super().__init__(**kwargs)
718722
719723 @property
@@ -727,7 +731,7 @@ def processed_dir_main(self) -> str:
727731 return os .path .join (
728732 self .base_dir ,
729733 self ._name ,
730- f"partial_{ self .top_class_id } " ,
734+ f"partial_{ self .top_class_id } _ext_ratio_ { self . external_data_ratio :.2f } " ,
731735 "processed" ,
732736 )
733737
@@ -746,9 +750,53 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
746750 descendants of the top class ID.
747751 """
748752 g = super ()._extract_class_hierarchy (chebi_path )
749- g = g .subgraph (list (g .successors (self .top_class_id )) + [self .top_class_id ])
753+ top_class_successors = list (g .successors (self .top_class_id )) + [
754+ self .top_class_id
755+ ]
756+ external_nodes = list (set (n for n in g .nodes if n not in top_class_successors ))
757+ if 0 < self .external_data_ratio < 1 :
758+ n_external_nodes = int (
759+ len (top_class_successors )
760+ * self .external_data_ratio
761+ / (1 - self .external_data_ratio )
762+ )
763+ print (
764+ f"Extracting { n_external_nodes } external nodes from the ChEBI dataset (ratio: { self .external_data_ratio :.2f} )"
765+ )
766+ external_nodes = external_nodes [: int (n_external_nodes )]
767+ elif self .external_data_ratio == 0 :
768+ external_nodes = []
769+
770+ g = g .subgraph (top_class_successors + external_nodes )
771+ print (
772+ f"Subgraph contains { len (g .nodes )} nodes, of which { len (top_class_successors )} are subclasses of the top class ID { self .top_class_id } ."
773+ )
750774 return g
751775
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
    """
    Select the label classes and write them to ``classes.txt``.

    A node is selected when both conditions hold:

    * it is the top class itself, or ``g`` has an edge from ``self.top_class_id`` to it
      (NOTE(review): presumably ``g`` is transitively closed, so this covers all
      subclasses and not only direct children -- confirm against the graph builder);
    * at least ``self.THRESHOLD`` of its successors in ``g`` have a ``smiles``
      attribute that is not None.

    Args:
        g (nx.DiGraph): The class hierarchy graph with a ``smiles`` node attribute.
        *args: Accepted for signature compatibility; unused.
        **kwargs: Accepted for signature compatibility; unused.

    Returns:
        List: The selected node IDs in ascending order. The same IDs are written, one
        per line, to ``classes.txt`` inside ``self.processed_dir_main``.
    """
    smiles = nx.get_node_attributes(g, "smiles")

    def in_top_class_scope(node) -> bool:
        # has_edge is a constant-time adjacency lookup, unlike scanning the
        # predecessor iterator of every node.
        return node == self.top_class_id or g.has_edge(self.top_class_id, node)

    def meets_threshold(node) -> bool:
        n_with_smiles = sum(smiles[s] is not None for s in g.successors(node))
        return n_with_smiles >= self.THRESHOLD

    # Cheap scope check first, so the successor count is only computed for nodes
    # that can actually be selected.
    nodes = sorted(
        node
        for node in g.nodes
        if in_top_class_scope(node) and meets_threshold(node)
    )

    filename = "classes.txt"
    with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
        fout.writelines(f"{node}\n" for node in nodes)
    return nodes
799+
752800
753801class ChEBIOver50Partial (ChEBIOverXPartial , ChEBIOver50 ):
754802 """
@@ -842,7 +890,7 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
842890
843891
844892atom_index = (
845- "\*" ,
893+ r "\*" ,
846894 "H" ,
847895 "He" ,
848896 "Li" ,
@@ -1473,3 +1521,15 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
14731521]
14741522
14751523JCI_500_COLUMNS_INT = [int (n .split (":" )[- 1 ]) for n in JCI_500_COLUMNS ]
1524+
if __name__ == "__main__":
    # Manual run: build the partial ChEBI50 data module for top class 22712 with an
    # external data ratio of 0.5, then run its prepare and setup phases.
    splits_csv = os.path.join("data", "chebi_v241", "ChEBI50", "splits_80_10_10.csv")
    data_module_05 = ChEBIOver50Partial(
        chebi_version=241,
        splits_file_path=splits_csv,
        top_class_id=22712,
        external_data_ratio=0.5,
    )
    data_module_05.prepare_data()
    data_module_05.setup()
0 commit comments