Skip to content

Commit 07340cb

Browse files
committed
read only first row to validate presence of relevant columns in csv
1 parent 9992a15 commit 07340cb

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def __init__(
155155
single_class=self.single_class,
156156
**_init_kwargs,
157157
)
158-
158+
# Path of csv file which contains a list of chebi ids & their assignment to a dataset (either train, validation or test).
159159
self.splits_file_path = self._validate_splits_file_path(
160160
kwargs.get("splits_file_path", None)
161161
)
@@ -189,8 +189,8 @@ def _validate_splits_file_path(splits_file_path=None):
189189
if not splits_file_path.lower().endswith(".csv"):
190190
raise ValueError(f"File {splits_file_path} is not a CSV file")
191191

192-
# Read the CSV file into a DataFrame
193-
splits_df = pd.read_csv(splits_file_path)
192+
# Read only the header and the first data row of the CSV file; that is enough to validate the column names
193+
splits_df = pd.read_csv(splits_file_path, nrows=1)
194194

195195
# Check if 'id' and 'split' columns are in the DataFrame
196196
required_columns = {"id", "split"}
@@ -604,7 +604,7 @@ def prepare_data(self, *args, **kwargs):
604604
Prepares the data for the Chebi dataset.
605605
606606
This method checks for the presence of raw data in the specified directory.
607-
If the raw data is missing, it fetches the ontology and creates a test test set.
607+
If the raw data is missing, it fetches the ontology and creates a test set.
608608
If the test set already exists, it loads it from the file.
609609
Then, it creates the train/validation split based on the test set.
610610
@@ -780,8 +780,10 @@ def dynamic_split_dfs(self):
780780
]
781781
):
782782
if self.splits_file_path is None:
783+
# Generate splits based on the given seed and create a csv file to record the splits
783784
self._generate_dynamic_splits()
784785
else:
786+
# If the user has provided a splits file path, use that file to retrieve the splits
785787
self._retreive_splits_from_csv()
786788
return {
787789
"train": self.dynamic_df_train,

0 commit comments

Comments
 (0)