@@ -44,6 +44,8 @@ def normalize_prompts(row) -> List[str]:
4444 "" ,
4545 )
4646 prompts .append (prompt )
47+ else : # we cannot handle this type
48+ continue
4749 elif isinstance (row , str ): # if the row is already a prompt
4850 prompts .append (row )
4951 elif (
@@ -78,109 +80,112 @@ def load_datasets(
7880 Return:
7981 (List[Dataset], str): A list of Dataset objects and the sorting strategy.
8082 """
83+ if datasets_config_file is None :
84+ logger .error ("Customized data loading logic needs to be implemented!" )
85+ return [], None
86+
8187 # Load datasets configuration file
82- if datasets_config_file :
83- try :
84- with open (datasets_config_file , "r" ) as f :
85- datasets_config = json .load (f )
86- except FileNotFoundError :
87- logger .error (
88- f"Configuration file '{ datasets_config_file } ' not found"
89- )
90- return [], None
91- except Exception as e :
92- logger .error (f"Error reading '{ datasets_config_file } ': { e } " )
93- return [], None
94-
95- # Strategy to sort the provided datasets
96- sort_strategy = datasets_config .pop ("sort" , "random" )
97-
98- # List to store each Dataset
99- datasets = []
88+ try :
89+ with open (datasets_config_file , "r" ) as f :
90+ datasets_config = json .load (f )
91+ except FileNotFoundError :
92+ logger .error (
93+ f"Configuration file '{ datasets_config_file } ' not found"
94+ )
95+ return [], None
96+ except Exception as e :
97+ logger .error (f"Error reading '{ datasets_config_file } ': { e } " )
98+ return [], None
10099
101- for name , config in datasets_config .items ():
102- file_name = config .get ("file_name" )
103- prompt_field = config .get ("prompt_field" )
100+ # Strategy to sort the provided datasets
101+ sort_strategy = datasets_config .pop ("sort_strategy" , "random" )
104102
105- try :
106- ratio = int (config .get ("select_ratio" , 1 ))
107- except ValueError :
108- logger .error (
109- f"Invalid 'select_ratio' for dataset '{ name } ', using default 1"
110- )
111- ratio = 1
103+ # List to store each Dataset
104+ datasets = []
112105
113- if not file_name or not prompt_field :
114- logger .error (
115- f"Missing required 'file_name' or 'prompt_field' for dataset '{ name } '"
116- )
117- continue
106+ for name , config in datasets_config .items ():
107+ file_name = config .get ("file_name" )
108+ prompt_field = config .get ("prompt_field" )
118109
119- file_path = (
120- os .path .abspath (file_name )
121- if os .path .exists (file_name )
122- else os .path .join (DEFAULT_DATASET_FOLDER , file_name )
110+ try :
111+ ratio = int (config .get ("select_ratio" , 1 ))
112+ except ValueError :
113+ logger .error (
114+ f"Invalid 'select_ratio' for dataset '{ name } ', using default 1"
123115 )
116+ ratio = 1
124117
125- # Load dataset from local files
126- if os .path .exists (file_path ):
127- prompts = []
128- # CSV files
129- if file_name .endswith (".csv" ):
130- data = pd .read_csv (file_path )
118+ if not file_name or not prompt_field :
119+ logger .error (
120+ f"Missing required 'file_name' or 'prompt_field' for dataset '{ name } '"
121+ )
122+ continue
123+
124+ os .makedirs (DEFAULT_DATASET_FOLDER , exist_ok = True )
125+
126+ file_path = (
127+ os .path .abspath (file_name )
128+ if os .path .exists (file_name )
129+ else os .path .join (DEFAULT_DATASET_FOLDER , file_name )
130+ )
131131
132- if prompt_field not in set (data .columns ):
132+ # Load dataset from local files
133+ if os .path .exists (file_path ):
134+ prompts = []
135+ # CSV files
136+ if file_name .endswith (".csv" ):
137+ data = pd .read_csv (file_path )
138+
139+ if prompt_field not in set (data .columns ):
140+ logger .error (
141+ f"Field '{ prompt_field } ' not found in '{ file_path } '."
142+ )
143+ continue
144+ prompts = data [prompt_field ].dropna ().astype (str ).tolist ()
145+ # JSON files
146+ elif file_name .endswith (".json" ):
147+ with open (file_path , "r" ) as f :
148+ data = json .load (f )
149+
150+ if isinstance (data , dict ):
151+ prompts = data .get (prompt_field , [])
152+ if not isinstance (prompts , list ):
133153 logger .error (
134- f"Field '{ prompt_field } ' not found in '{ file_path } '."
154+ f"Field '{ prompt_field } ' in '{ file_path } ' is not a list ."
135155 )
136156 continue
137- prompts = data [prompt_field ].dropna ().astype (str ).tolist ()
138- # JSON files
139- elif file_name .endswith (".json" ):
140- with open (file_path , "r" ) as f :
141- data = json .load (f )
142-
143- if isinstance (data , dict ):
144- prompts = data .get (prompt_field , [])
145- if not isinstance (prompts , list ):
146- logger .error (
147- f"Field '{ prompt_field } ' in '{ file_path } ' is not a list."
148- )
149- continue
150- else :
151- logger .error (f"Unsupported file format for '{ file_name } '" )
152- continue
153157 else :
154- try :
155- if file_name .endswith (".csv" ): # CSV format
156- data = pd .read_csv (file_name )
157-
158- if prompt_field not in set (data .columns ):
159- logger .error (
160- f"Field '{ prompt_field } ' not found in '{ file_name } '."
161- )
162- continue
163- prompts = (
164- data [prompt_field ].dropna ().astype (str ).tolist ()
158+ logger .error (f"Unsupported file format for '{ file_name } '" )
159+ continue
160+ else :
161+ try :
162+ if file_name .endswith (".csv" ): # CSV format
163+ data = pd .read_csv (file_name )
164+
165+ if prompt_field not in set (data .columns ):
166+ logger .error (
167+ f"Field '{ prompt_field } ' not found in '{ file_name } '."
165168 )
166- else : # use datasets to load
167- data = load_dataset (file_name )["train" ]
168- prompts = []
169- for row in data [prompt_field ]:
170- prompts .extend (normalize_prompts (row ))
171- except Exception as e :
172- logger .error (f"Failed to load '{ file_name } ': { e } " )
173-
174- # Add the dataset information (file name, a list of prompts, select ratio among all datasets, total number of prompts)
175- dataset_obj = Dataset (file_name , prompts , ratio , len (prompts ))
176- datasets .append (dataset_obj )
177-
178- logger .info (
179- f"loaded { file_name } with { len (prompts )} prompts, selection ratio = { ratio } "
180- )
169+ continue
170+ prompts = (
171+ data [prompt_field ].dropna ().astype (str ).tolist ()
172+ )
173+ else : # use datasets to load
174+ data = load_dataset (file_name )["train" ]
175+ prompts = []
176+ for row in data [prompt_field ]:
177+ prompts .extend (normalize_prompts (row ))
178+ except Exception as e :
179+ logger .error (f"Failed to load '{ file_name } ': { e } " )
180+
181+ # Add the dataset information (file name, a list of prompts, select ratio among all datasets, total number of prompts)
182+ dataset_obj = Dataset (file_name , prompts , ratio , len (prompts ))
183+ datasets .append (dataset_obj )
184+
185+ logger .info (
186+ f"loaded { file_name } with { len (prompts )} prompts, selection ratio = { ratio } "
187+ )
181188
182- return datasets , sort_strategy
189+ return datasets , sort_strategy
183190
184- else :
185- logger .error ("Customized data loading logic needs to be implemented!" )
186- return [], None
191+
0 commit comments