@@ -143,7 +143,9 @@ def _fix_streaming_keys(row):
143143 new_k = k [len ('__@' ):]
144144 row [new_k ] = row .pop (k )
145145
146- def batched_preprocess (self , batched_row : Dict [str , Any ], * , strict : bool ) -> Dict [str , Any ]:
146+ def batched_preprocess (self , batched_row : Dict [str , Any ], * , strict : bool ,
147+ ignore_max_length_error : bool ) -> Dict [str , Any ]:
148+ from ...template import MaxLengthError
147149 batched_row = dict (batched_row )
148150 assert len (batched_row ) > 0
149151 self ._fix_streaming_keys (batched_row )
@@ -162,13 +164,15 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Di
162164 self ._check_messages (r )
163165 self ._check_rejected_response (r )
164166 self ._cast_images (r )
165- except Exception :
167+ except Exception as e :
166168 if strict :
167169 logger .warning ('To avoid errors, you can pass `strict=False`.' )
168170 raise
169- if self .traceback_limit is not None and self ._traceback_counter < self .traceback_limit :
171+ if isinstance (e , MaxLengthError ) and ignore_max_length_error :
172+ pass
173+ elif self .traceback_limit is not None and self ._traceback_counter < self .traceback_limit :
170174 import traceback
171- print (traceback .format_exc ())
175+ logger . info (traceback .format_exc ())
172176 logger .error ('👆👆👆There are errors in the dataset, the data will be deleted' )
173177 self ._traceback_counter += 1
174178 row = []
@@ -256,15 +260,21 @@ def __call__(
256260 dataset = self .prepare_dataset (dataset )
257261 dataset = self ._cast_pil_image (dataset )
258262 map_kwargs = {}
263+ ignore_max_length_error = False
259264 if isinstance (dataset , HfDataset ):
260265 map_kwargs ['num_proc' ] = num_proc
266+ if num_proc > 1 :
267+ ignore_max_length_error = True
261268 with self ._patch_arrow_writer ():
262269 try :
263270 dataset_mapped = dataset .map (
264271 self .batched_preprocess ,
265272 batched = True ,
266273 batch_size = batch_size ,
267- fn_kwargs = {'strict' : strict },
274+ fn_kwargs = {
275+ 'strict' : strict ,
276+ 'ignore_max_length_error' : ignore_max_length_error
277+ },
268278 remove_columns = list (dataset .features .keys ()),
269279 ** map_kwargs )
270280 except NotImplementedError :
0 commit comments