Skip to content

Commit 9992a15

Browse files
committed
logic to generate splits csv + use csv if provided
1 parent 1c4acea commit 9992a15

File tree

1 file changed

+83
-2
lines changed

1 file changed

+83
-2
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,51 @@ def __init__(
156156
**_init_kwargs,
157157
)
158158

159+
self.splits_file_path = self._validate_splits_file_path(
160+
kwargs.get("splits_file_path", None)
161+
)
162+
163+
@staticmethod
def _validate_splits_file_path(splits_file_path=None):
    """
    Validates the provided splits file path.

    Args:
        splits_file_path (str, os.PathLike or None): Path to the splits CSV file.

    Returns:
        str or None: The validated splits file path (as a string) if all
            checks pass, None if splits_file_path is None.

    Raises:
        FileNotFoundError: If the splits file does not exist.
        ValueError: If the splits file is empty, is not a CSV file, or is
            missing the required columns ('id' and/or 'split').
    """
    if splits_file_path is None:
        return None

    # Accept pathlib.Path and other path-like objects; the extension check
    # below relies on string methods.
    splits_file_path = os.fspath(splits_file_path)

    if not os.path.isfile(splits_file_path):
        raise FileNotFoundError(f"File {splits_file_path} does not exist")

    if os.path.getsize(splits_file_path) == 0:
        raise ValueError(f"File {splits_file_path} is empty")

    # Check if the file has a CSV extension
    if not splits_file_path.lower().endswith(".csv"):
        raise ValueError(f"File {splits_file_path} is not a CSV file")

    # Only the header is needed to verify the columns, so avoid parsing the
    # data rows here.
    header_df = pd.read_csv(splits_file_path, nrows=0)

    # Check if 'id' and 'split' columns are in the DataFrame
    required_columns = {"id", "split"}
    if not required_columns.issubset(header_df.columns):
        raise ValueError(
            f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')."
        )

    return splits_file_path
203+
159204
def extract_class_hierarchy(self, chebi_path):
160205
"""
161206
Extracts the class hierarchy from the ChEBI ontology.
@@ -632,7 +677,7 @@ def prepare_data(self, *args, **kwargs):
632677
# Generate the "chebi_version_train" data if it doesn't exist
633678
self._chebi_version_train_obj.prepare_data(*args, **kwargs)
634679

635-
def _get_dynamic_splits(self):
680+
def _generate_dynamic_splits(self):
636681
"""Generate data splits during run-time and saves in class variables"""
637682

638683
# Load encoded data derived from "chebi_version"
@@ -687,10 +732,43 @@ def _get_dynamic_splits(self):
687732
)
688733
df_test = df_test_chebi_ver
689734

735+
# Generate splits.csv file to store ids of each corresponding split
736+
split_assignment_list: List[pd.DataFrame] = [
737+
pd.DataFrame({"id": df_train["ident"], "split": "train"}),
738+
pd.DataFrame({"id": df_val["ident"], "split": "validation"}),
739+
pd.DataFrame({"id": df_test["ident"], "split": "test"}),
740+
]
741+
combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
742+
combined_split_assignment.to_csv(
743+
os.path.join(self.processed_dir_main, "splits.csv")
744+
)
745+
746+
# Store the splits in class variables
690747
self.dynamic_df_train = df_train
691748
self.dynamic_df_val = df_val
692749
self.dynamic_df_test = df_test
693750

751+
def _retreive_splits_from_csv(self):
    """
    Restores the train/validation/test splits from the provided splits CSV.

    Reads ``self.splits_file_path`` (columns 'id' and 'split'), loads the
    processed data file ``self.processed_file_names_dict["data"]`` from
    ``self.processed_dir`` and, for each split label, keeps the data rows
    whose 'ident' value is listed under that label. The results are stored
    in ``self.dynamic_df_train``, ``self.dynamic_df_val`` and
    ``self.dynamic_df_test``.

    Ids listed in the CSV that have no matching row are ignored, and data
    rows not listed under 'train', 'validation' or 'test' end up in no split.
    """
    splits_df = pd.read_csv(self.splits_file_path)

    filename = self.processed_file_names_dict["data"]
    # NOTE: torch.load unpickles the file; it is expected to be the
    # dataset's own processed output, not untrusted input.
    data_chebi_version = torch.load(os.path.join(self.processed_dir, filename))
    df_chebi_version = pd.DataFrame(data_chebi_version)

    def _rows_for(split_name):
        # Select the data rows whose 'ident' is assigned to `split_name`.
        # NOTE(review): matching relies on the CSV 'id' values parsing to the
        # same dtype as the stored 'ident' values - confirm for string ids.
        split_ids = splits_df.loc[splits_df["split"] == split_name, "id"]
        return df_chebi_version[df_chebi_version["ident"].isin(split_ids)]

    self.dynamic_df_train = _rows_for("train")
    self.dynamic_df_val = _rows_for("validation")
    self.dynamic_df_test = _rows_for("test")
771+
694772
@property
695773
def dynamic_split_dfs(self):
696774
if any(
@@ -701,7 +779,10 @@ def dynamic_split_dfs(self):
701779
self.dynamic_df_train,
702780
]
703781
):
704-
self._get_dynamic_splits()
782+
if self.splits_file_path is None:
783+
self._generate_dynamic_splits()
784+
else:
785+
self._retreive_splits_from_csv()
705786
return {
706787
"train": self.dynamic_df_train,
707788
"validation": self.dynamic_df_val,

0 commit comments

Comments
 (0)