@@ -101,6 +101,15 @@ def task_generator(
101101 yield task
102102
103103
def load_hf_dataset(config: DataConfig, split: str):
    """Load one split of the Hugging Face dataset described by ``config``.

    ``config.dataset_path`` is always handed to ``load_dataset``. When
    ``config.subset_name`` is not ``None`` it is forwarded as the second
    positional argument (the dataset configuration name); otherwise it is
    left out of the call entirely.

    Args:
        config: Data configuration providing ``dataset_path`` and
            ``subset_name``.
        split: Value passed through to ``load_dataset`` as ``split=``.

    Returns:
        Whatever ``load_dataset`` returns for the requested split.
    """
    # Guard clause: no subset configured, so call without a configuration name.
    if config.subset_name is None:
        return load_dataset(config.dataset_path, split=split)
    return load_dataset(config.dataset_path, config.subset_name, split=split)
111+
112+
104113@dataclass
105114class TaskSet :
106115 """A TaskSet class that defines a set of tasks and their associated reward functions."""
@@ -125,7 +134,8 @@ def load(
125134 # disable datasets caching to avoid reusing an old-version dataset
126135 datasets .disable_caching ()
127136 if task_type == TaskType .EVAL :
128- dataset = load_dataset (config .dataset_path )[config .eval_split ]
137+ assert config .eval_split is not None , "eval_split must be provided for eval taskset."
138+ dataset = load_hf_dataset (config , config .eval_split )
129139 else : # default
130140 if task_type != TaskType .EVAL and config .db_url != "" :
131141 logger .info (f"Loading dataset from database with url: { config .db_url } " )
@@ -134,7 +144,7 @@ def load(
134144 dataset = Dataset .from_sql (RftDatasetModel .__tablename__ , f"{ db_type } :///{ db_name } " )
135145 elif config .dataset_path != "" :
136146 logger .info (f"Loading dataset from local file with path: { config .dataset_path } ." )
137- dataset = load_dataset (config . dataset_path )[ config .train_split ]
147+ dataset = load_hf_dataset (config , config .train_split )
138148 else :
139149 raise ValueError ("No dataset path or db url provided." )
140150 datasets .enable_caching ()
0 commit comments