@@ -112,7 +112,7 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
112112 labels .add (os .path .basename (os .path .dirname (downloaded_dir_file )))
113113 path_depths .add (count_path_segments (downloaded_dir_file ))
114114 elif os .path .basename (downloaded_dir_file ) in metadata_filenames :
115- metadata_files [split ].add ((None , downloaded_dir_file ))
115+ metadata_files [split ].add ((None , downloaded_dir , downloaded_dir_file ))
116116 else :
117117 archive_file_name = os .path .basename (archive )
118118 original_file_name = os .path .basename (downloaded_dir_file )
@@ -123,8 +123,6 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
123123 data_files = self .config .data_files
124124 splits = []
125125 for split_name , files in data_files .items ():
126- if isinstance (files , str ):
127- files = [files ]
128126 files , archives = self ._split_files_and_archives (files )
129127 downloaded_files = dl_manager .download (files )
130128 downloaded_dirs = dl_manager .download_and_extract (archives )
@@ -156,12 +154,17 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
156154 else :
157155 add_labels , add_metadata , metadata_files = False , False , {}
158156
157+ # files info (original_file, downloaded_file)
158+ files = tuple (zip (files , downloaded_files ))
159+ # dirs info (original_file, downloaded_dir, downloaded_files)
160+ files += tuple (
161+ (None , downloaded_dir , dl_manager .iter_files (downloaded_dir )) for downloaded_dir in downloaded_dirs
162+ )
159163 splits .append (
160164 datasets .SplitGenerator (
161165 name = split_name ,
162166 gen_kwargs = {
163- "files" : tuple (zip (files , downloaded_files ))
164- + tuple ((None , dl_manager .iter_files (downloaded_dir )) for downloaded_dir in downloaded_dirs ),
167+ "files" : files ,
165168 "metadata_files" : metadata_files .get (split_name , []),
166169 "add_labels" : add_labels ,
167170 "add_metadata" : add_metadata ,
@@ -267,7 +270,7 @@ def _split_files_and_archives(self, data_files):
267270 files .append (data_file )
268271 elif os .path .basename (data_file ) in metadata_filenames :
269272 files .append (data_file )
270- else :
273+ elif data_file_ext . lower () == ".zip" :
271274 archives .append (data_file )
272275 return files , archives
273276
@@ -354,6 +357,14 @@ def _read_metadata(self, metadata_file: str, metadata_ext: str = "") -> Iterator
354357 ):
355358 yield pa .Table .from_batches ([record_batch ])
356359
def _generate_shards(self, files, metadata_files, add_metadata, add_labels):
    """Yield one downloaded path per shard.

    Entries come in two shapes, both of which are produced elsewhere in
    this builder and both of which ``_generate_examples`` accepts:

    * ``(original_file, downloaded_file)`` for a plain file;
    * ``(None, downloaded_dir, downloaded_files)`` for an extracted
      archive in ``files``, or ``(None, downloaded_dir, downloaded_file)``
      for a metadata file found inside an extracted archive.

    Args:
        files: Iterable of file/dir info tuples (2 or 3 elements each).
        metadata_files: Iterable of metadata info tuples (2 or 3 elements
            each).
        add_metadata: When true, shards are the metadata files; otherwise
            they are the entries of ``files``.
        add_labels: Unused here; accepted so the signature mirrors
            ``_generate_examples``.

    Yields:
        With ``add_metadata``: the downloaded metadata file path of each
        entry. Otherwise: the downloaded file path for a 2-tuple, or the
        downloaded directory for a 3-tuple.
    """
    if add_metadata:
        for metadata_file_info in metadata_files:
            # The downloaded metadata file is the last element in both the
            # 2-tuple and the 3-tuple layout.
            yield metadata_file_info[-1]
    else:
        for file_or_dir_info in files:
            # Index 1 is the downloaded file (2-tuple) or the downloaded
            # directory (3-tuple). The iterator of extracted files at index 2
            # is deliberately left unconsumed so that example generation can
            # still iterate it.
            yield file_or_dir_info[1]
367+
357368 def _generate_examples (self , files , metadata_files , add_metadata , add_labels ):
358369 if add_metadata :
359370 feature_paths = []
@@ -365,7 +376,11 @@ def find_feature_path(feature, feature_path):
365376
366377 _visit_with_path (self .info .features , find_feature_path )
367378
368- for shard_idx , (original_metadata_file , downloaded_metadata_file ) in enumerate (metadata_files ):
379+ for shard_idx , metadata_file_info in enumerate (metadata_files ):
380+ if len (metadata_file_info ) == 2 :
381+ original_metadata_file , downloaded_metadata_file = metadata_file_info
382+ else :
383+ original_metadata_file , downloaded_metadata_dir , downloaded_metadata_file = metadata_file_info
369384 metadata_ext = os .path .splitext (original_metadata_file or downloaded_metadata_file )[- 1 ]
370385 downloaded_metadata_dir = os .path .dirname (downloaded_metadata_file )
371386
@@ -395,12 +410,13 @@ def set_feature(item, feature_path: _VisitPath):
395410 if isinstance (self .config .filters , list )
396411 else self .config .filters
397412 )
398- for shard_idx , (original_file , downloaded_file_or_dir ) in enumerate (files ):
399- downloaded_files = [downloaded_file_or_dir ] if original_file else downloaded_file_or_dir
413+ for shard_idx , file_or_dir_info in enumerate (files ):
414+ if len (file_or_dir_info ) == 2 :
415+ original_file , downloaded_file = file_or_dir_info
416+ downloaded_files = [downloaded_file ]
417+ else :
418+ original_file , downloaded_dir , downloaded_files = file_or_dir_info
400419 for sample_idx , downloaded_file in enumerate (downloaded_files ):
401- original_file_ext = os .path .splitext (original_file or downloaded_file )[- 1 ]
402- if original_file_ext .lower () not in self .EXTENSIONS :
403- continue
404420 sample = {self .BASE_COLUMN_NAME : downloaded_file }
405421 if add_labels :
406422 sample ["label" ] = os .path .basename (os .path .dirname (original_file or downloaded_file ))
0 commit comments