11from datasets import Dataset
22
3- from eval_framework .tasks .base import SubjectType
3+ from eval_framework .tasks .base import NO_SUBJECT , SubjectType
44from eval_framework .tasks .benchmarks .copa import COPA
55
66
@@ -14,9 +14,6 @@ def split_dataset_by_id_ranges(
1414 id_column: The name of the column containing the id values.
1515 ranges: A list of (low, high) tuples defining inclusive ranges.
1616 Rows whose id is within any of these ranges go into the first split.
17-
18- Returns:
19-
2017 """
2118
2219 def in_any_range (id_value : int ) -> bool :
@@ -36,13 +33,22 @@ class BalancedCOPA(COPA):
3633 HF_REVISION = "813bd03cd6e07d9bd8d7333896ad5d40abb95ea9"
3734 SUBJECTS = ["no_subject" ]
3835
39- def _resplit_dataset_into_train_and_val (self ) -> None :
36+ def _split_dataset_into_train_and_val (self , dataset ) -> None :
4037 # We split the train data into train and validation splits so that
4138 # the validation split matches the validation split of the original COPA dataset.
42- self .dataset ["train" ], self .dataset ["validation" ] = split_dataset_by_id_ranges (
43- self .dataset ["train" ], "id" , [(401 , 500 ), (1401 , 1500 )]
39+ # These magic numbers of the ids below were arrived at after manual inspection of the dataset.
40+ # The sanity of this version is maintained by the HF_REVISION above.
41+ dataset ["validation" ], dataset ["train" ] = split_dataset_by_id_ranges (
42+ dataset ["train" ], "id" , [(401 , 500 ), (1401 , 1500 )]
4443 )
44+ return dataset
4545
4646 def _load_dataset (self , subject : SubjectType ) -> None :
47- super ()._load_dataset (subject )
48- self ._resplit_dataset_into_train_and_val ()
47+ # This method largely reimplements the _load_dataset method in the base class,
48+ # as the _shuffle_splits method drops any column not in FEWSHOT_SPLIT, SAMPLE_SPLIT.
49+ # Thus, we need to split the dataset into train and validation splits before shuffling.
50+ name = subject if subject != NO_SUBJECT else None
51+ hf_dataset = self ._load_hf_dataset (path = self .DATASET_PATH , name = name )
52+ hf_dataset = self ._split_dataset_into_train_and_val (hf_dataset )
53+
54+ self .dataset = self ._shuffle_splits (hf_dataset = hf_dataset )
0 commit comments