@@ -1,4 +1,3 @@
-import itertools
 from dataclasses import dataclass
 from typing import Any, Callable, Optional, Union
 
@@ -154,13 +153,17 @@ def _split_generators(self, dl_manager):
         if not self.config.data_files:
             raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
         dl_manager.download_config.extract_on_the_fly = True
-        data_files = dl_manager.download_and_extract(self.config.data_files)
+        base_data_files = dl_manager.download(self.config.data_files)
+        extracted_data_files = dl_manager.extract(base_data_files)
         splits = []
-        for split_name, files in data_files.items():
-            if isinstance(files, str):
-                files = [files]
-            files = [dl_manager.iter_files(file) for file in files]
-            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
+        for split_name, extracted_files in extracted_data_files.items():
+            files_iterables = [dl_manager.iter_files(extracted_file) for extracted_file in extracted_files]
+            splits.append(
+                datasets.SplitGenerator(
+                    name=split_name,
+                    gen_kwargs={"files_iterables": files_iterables, "base_files": base_data_files[split_name]},
+                )
+            )
         return splits
 
     def _cast_table(self, pa_table: pa.Table) -> pa.Table:
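
Note on the wiring above: datasets unpacks each SplitGenerator's gen_kwargs dict as keyword
arguments into the builder's table generator, so the keys chosen here ("files_iterables",
"base_files") have to match the _generate_tables(self, base_files, files_iterables) signature
introduced further down. A minimal sketch of that contract, with hypothetical paths standing in
for what dl_manager would actually return:

# Sketch of the SplitGenerator -> generator contract; all paths are made up.
def fake_generate_tables(base_files, files_iterables):
    for shard_idx, files_iterable in enumerate(files_iterables):
        for file in files_iterable:
            yield shard_idx, file

gen_kwargs = {
    "files_iterables": [["extracted/train.csv"]],  # hypothetical extracted file
    "base_files": ["downloads/train.csv.gz"],      # hypothetical downloaded archive
}
# The builder effectively calls _generate_tables(**gen_kwargs):
assert list(fake_generate_tables(**gen_kwargs)) == [(0, "extracted/train.csv")]
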
@@ -174,7 +177,10 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
                 pa_table = table_cast(pa_table, schema)
         return pa_table
 
-    def _generate_tables(self, files):
+    def _generate_shards(self, base_files, files_iterables):
+        yield from base_files
+
+    def _generate_tables(self, base_files, files_iterables):
         schema = self.config.features.arrow_schema if self.config.features else None
         # dtype allows reading an int column as str
         dtype = (
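
_generate_shards just re-yields the downloaded files, presumably so that anything enumerating
shards sees the original (pre-extraction) paths rather than the extracted ones. A toy check of
that pass-through, assuming base_files is a flat list of paths with one extracted-files iterable
per entry:

# Toy check of the pass-through; file names are hypothetical.
def generate_shards(base_files, files_iterables):
    yield from base_files

base_files = ["downloads/a.csv.gz", "downloads/b.csv.gz"]
files_iterables = [["ext/a.csv"], ["ext/b.csv"]]
assert list(generate_shards(base_files, files_iterables)) == base_files
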
@@ -185,15 +191,16 @@ def _generate_tables(self, files):
             if schema is not None
             else None
         )
-        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
-            csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
-            try:
-                for batch_idx, df in enumerate(csv_file_reader):
-                    pa_table = pa.Table.from_pandas(df)
-                    # Uncomment for debugging (will print the Arrow table size and elements)
-                    # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
-                    # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
-                    yield Key(file_idx, batch_idx), self._cast_table(pa_table)
-            except ValueError as e:
-                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
-                raise
+        for shard_idx, files_iterable in enumerate(files_iterables):
+            for file in files_iterable:
+                csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
+                try:
+                    for batch_idx, df in enumerate(csv_file_reader):
+                        pa_table = pa.Table.from_pandas(df)
+                        # Uncomment for debugging (will print the Arrow table size and elements)
+                        # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
+                        # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
+                        yield Key(shard_idx, batch_idx), self._cast_table(pa_table)
+                except ValueError as e:
+                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                    raise
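
For reference, the chunked-read pattern in the loop above can be reproduced standalone:
pd.read_csv(..., iterator=True) returns a reader that yields DataFrames, and each chunk converts
to an Arrow table. A self-contained sketch with an in-memory CSV (in the real code, chunksize and
the other reader options come in via self.config.pd_read_csv_kwargs):

# Self-contained sketch of the batched CSV -> Arrow conversion.
import io

import pandas as pd
import pyarrow as pa

csv_data = io.StringIO("a,b\n1,x\n2,y\n3,z\n")
# chunksize=2 is only there to produce multiple batches; with iterator=True
# and no chunksize, the reader yields the whole file as a single DataFrame.
reader = pd.read_csv(csv_data, iterator=True, chunksize=2)
for batch_idx, df in enumerate(reader):
    pa_table = pa.Table.from_pandas(df)
    print(batch_idx, pa_table.num_rows)  # prints "0 2" then "1 1"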