11import os
22import random
33from abc import ABC , abstractmethod
4- from typing import Any , Dict , Generator , List , Optional , Tuple , Union
4+ from pathlib import Path
5+ from typing import TYPE_CHECKING , Any , Dict , Generator , List , Optional , Tuple , Union
56
67import lightning as pl
7- import networkx as nx
88import pandas as pd
99import torch
1010import tqdm
11- from iterstrat .ml_stratifiers import (
12- MultilabelStratifiedKFold ,
13- MultilabelStratifiedShuffleSplit ,
14- )
1511from lightning .pytorch .core .datamodule import LightningDataModule
1612from lightning_utilities .core .rank_zero import rank_zero_info
17- from sklearn .model_selection import StratifiedShuffleSplit
1813from torch .utils .data import DataLoader
1914
2015from chebai .preprocessing import reader as dr
2116
17+ if TYPE_CHECKING :
18+ import networkx as nx
19+
2220
2321class XYBaseDataModule (LightningDataModule ):
2422 """
@@ -419,10 +417,17 @@ def prepare_data(self, *args, **kwargs) -> None:
419417
420418 self ._prepare_data_flag += 1
421419 self ._perform_data_preparation (* args , ** kwargs )
420+ self ._after_prepare_data (* args , ** kwargs )
422421
def _perform_data_preparation(self, *args, **kwargs) -> None:
    """
    Carry out the actual data preparation work.

    Invoked by `prepare_data`, which forwards the positional and keyword
    arguments it received. This implementation is only a placeholder and
    always raises; presumably concrete data modules override it.

    Args:
        *args: Positional arguments forwarded from `prepare_data`.
        **kwargs: Keyword arguments forwarded from `prepare_data`.

    Raises:
        NotImplementedError: Always, in this placeholder implementation.
    """
    raise NotImplementedError()
425424
def _after_prepare_data(self, *args, **kwargs) -> None:
    """
    Hook for extra processing once the pre-processed data is available.

    `prepare_data` calls this right after `_perform_data_preparation`,
    passing along the same arguments. The default implementation
    deliberately does nothing.

    Args:
        *args: Positional arguments forwarded from `prepare_data`.
        **kwargs: Keyword arguments forwarded from `prepare_data`.
    """
    # Intentional no-op: there is nothing to do by default.
    return None
430+
426431 def setup (self , * args , ** kwargs ) -> None :
427432 """
428433 Setup the data module.
@@ -464,14 +469,17 @@ def _set_processed_data_props(self):
464469 - self._num_of_labels: Number of target labels in the dataset.
465470 - self._feature_vector_size: Maximum feature vector length across all data points.
466471 """
467- data_pt = torch .load (
468- os .path .join (self .processed_dir , self .processed_file_names_dict ["data" ]),
469- weights_only = False ,
472+ pt_file_path = os .path .join (
473+ self .processed_dir , self .processed_file_names_dict ["data" ]
470474 )
475+ data_pt = torch .load (pt_file_path , weights_only = False )
471476
472477 self ._num_of_labels = len (data_pt [0 ]["labels" ])
473478 self ._feature_vector_size = max (len (d ["features" ]) for d in data_pt )
474479
480+ print (
481+ f"Number of samples in encoded data ({ pt_file_path } ): { len (data_pt )} samples"
482+ )
475483 print (f"Number of labels for loaded data: { self ._num_of_labels } " )
476484 print (f"Feature vector size: { self ._feature_vector_size } " )
477485
@@ -734,6 +742,7 @@ def __init__(
734742 self .splits_file_path = self ._validate_splits_file_path (
735743 kwargs .get ("splits_file_path" , None )
736744 )
745+ self ._data_pkl_filename : str = "data.pkl"
737746
738747 @staticmethod
739748 def _validate_splits_file_path (splits_file_path : Optional [str ]) -> Optional [str ]:
@@ -818,7 +827,7 @@ def _download_required_data(self) -> str:
818827 pass
819828
820829 @abstractmethod
821- def _extract_class_hierarchy (self , data_path : str ) -> nx .DiGraph :
830+ def _extract_class_hierarchy (self , data_path : str ) -> " nx.DiGraph" :
822831 """
823832 Extracts the class hierarchy from the data.
824833 Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -833,7 +842,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
833842 pass
834843
835844 @abstractmethod
836- def _graph_to_raw_dataset (self , graph : nx .DiGraph ) -> pd .DataFrame :
845+ def _graph_to_raw_dataset (self , graph : " nx.DiGraph" ) -> pd .DataFrame :
837846 """
838847 Converts the graph to a raw dataset.
839848 Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -848,7 +857,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
848857 pass
849858
850859 @abstractmethod
851- def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> List :
860+ def select_classes (self , g : " nx.DiGraph" , * args , ** kwargs ) -> List :
852861 """
853862 Selects classes from the dataset based on a specified criteria.
854863
@@ -872,6 +881,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
872881 """
873882 pd .to_pickle (data , open (os .path .join (self .processed_dir_main , filename ), "wb" ))
874883
def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
    """
    Load a pickled object from the main processed-data directory.

    Args:
        filename (str): Name of the pickle file, resolved relative to
            `self.processed_dir_main`.

    Returns:
        Optional[pd.DataFrame]: The unpickled content of the file (expected
            to be a DataFrame), or None when no file exists at that path.
    """
    pickle_path = Path(self.processed_dir_main).joinpath(filename)
    # Guard clause: a missing file is reported as None rather than an error.
    if not pickle_path.exists():
        return None
    return pd.read_pickle(pickle_path)
898+
875899 # ------------------------------ Phase: Setup data -----------------------------------
876900 def setup_processed (self ) -> None :
877901 """
@@ -910,7 +934,9 @@ def _get_data_size(input_file_path: str) -> int:
910934 int: The size of the data.
911935 """
912936 with open (input_file_path , "rb" ) as f :
913- return len (pd .read_pickle (f ))
937+ df = pd .read_pickle (f )
938+ print (f"Processed data size ({ input_file_path } ): { len (df )} rows" )
939+ return len (df )
914940
915941 @abstractmethod
916942 def _load_dict (self , input_file_path : str ) -> Generator [Dict [str , Any ], None , None ]:
@@ -1023,6 +1049,9 @@ def get_test_split(
10231049 Raises:
10241050 ValueError: If the DataFrame does not contain a column named "labels".
10251051 """
1052+ from iterstrat .ml_stratifiers import MultilabelStratifiedShuffleSplit
1053+ from sklearn .model_selection import StratifiedShuffleSplit
1054+
10261055 print ("Get test data split" )
10271056
10281057 labels_list = df ["labels" ].tolist ()
@@ -1060,6 +1089,12 @@ def get_train_val_splits_given_test(
10601089 and validation DataFrames. The keys are the names of the train and validation sets, and the values
10611090 are the corresponding DataFrames.
10621091 """
1092+ from iterstrat .ml_stratifiers import (
1093+ MultilabelStratifiedKFold ,
1094+ MultilabelStratifiedShuffleSplit ,
1095+ )
1096+ from sklearn .model_selection import StratifiedShuffleSplit
1097+
10631098 print ("Split dataset into train / val with given test set" )
10641099
10651100 test_ids = test_df ["ident" ].tolist ()
@@ -1217,7 +1252,7 @@ def processed_main_file_names_dict(self) -> dict:
12171252 dict: A dictionary mapping dataset key to their respective file names.
12181253 For example, {"data": "data.pkl"}.
12191254 """
1220- return {"data" : "data.pkl" }
1255+ return {"data" : self . _data_pkl_filename }
12211256
12221257 @property
12231258 def raw_file_names (self ) -> List [str ]:
0 commit comments