@@ -2,6 +2,7 @@
 import copy
 import inspect
 import itertools
+import multiprocessing.pool
 import sys
 from collections import Counter
 from collections.abc import Iterable, Iterator
@@ -24,6 +25,7 @@
     Value,
     _align_features,
     _check_if_features_can_be_aligned,
+    _visit,
     cast_to_python_objects,
 )
 from .formatting import (
@@ -1010,6 +1012,7 @@ def __init__(
         fn_kwargs: Optional[dict] = None,
         formatting: Optional["FormattingConfig"] = None,
         features: Optional[Features] = None,
+        max_num_running_async_map_functions_in_parallel: Optional[int] = None,
     ):
         super().__init__()
         self.ex_iterable = ex_iterable
@@ -1023,6 +1026,9 @@ def __init__(
         self.fn_kwargs = fn_kwargs or {}
         self.formatting = formatting  # required for iter_arrow
         self._features = features
+        self.max_num_running_async_map_functions_in_parallel = (
+            max_num_running_async_map_functions_in_parallel or config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL
+        )
         # sanity checks
         if formatting and formatting.is_table:
             # batch_size should match for iter_arrow
@@ -1036,6 +1042,8 @@ def __init__(
                 f"The {formatting.format_type.capitalize()}-formatted {type(self).__name__} has batch_size={batch_size if batched else 1} which is"
                 f"different from {ex_iterable.batch_size=} from its underlying iterable."
             )
+        # to enable graceful shutdown
+        self._owned_loops_and_tasks: list[tuple[asyncio.AbstractEventLoop, list[asyncio.Task]]] = []

     @property
     def iter_arrow(self):
@@ -1174,6 +1182,7 @@ async def async_apply_function(key_example, indices):
                 loop = asyncio.get_running_loop()
             except RuntimeError:
                 loop = asyncio.new_event_loop()
+                self._owned_loops_and_tasks.append((loop, tasks))
             else:
                 loop = None

@@ -1191,15 +1200,15 @@ def iter_outputs():
                     indices.append(i)
                     tasks.append(loop.create_task(async_apply_function(key_example, i)))
                     # keep the total active tasks under a certain number
-                    if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                    if len(tasks) >= self.max_num_running_async_map_functions_in_parallel:
                         done, pending = loop.run_until_complete(
                             asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                         )
-                        while tasks and len(pending) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                        while tasks and len(pending) >= self.max_num_running_async_map_functions_in_parallel:
                             done, pending = loop.run_until_complete(
                                 asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                             )
-                    if len(tasks) >= 10 * config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                    if len(tasks) >= 10 * self.max_num_running_async_map_functions_in_parallel:
                         loop.run_until_complete(tasks[0])
                     # yield finished tasks
                     while tasks and tasks[0].done():
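
The hunk above swaps the global cap (`config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL`) for a per-instance one. The underlying pattern is a bounded in-flight task queue drained with `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)`. Below is a minimal, self-contained sketch of that pattern, not library code; `work`, `run_throttled`, and `max_in_flight` are illustrative names:

```py
import asyncio


async def work(i: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for an async map function
    return i * i


def run_throttled(n: int, max_in_flight: int = 4) -> list[int]:
    loop = asyncio.new_event_loop()
    tasks: list[asyncio.Task] = []
    results: list[int] = []
    try:
        for i in range(n):
            tasks.append(loop.create_task(work(i)))
            # once the cap is hit, run the loop until at least one task finishes
            if len(tasks) >= max_in_flight:
                loop.run_until_complete(asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED))
            # pop finished tasks from the head to preserve input order
            while tasks and tasks[0].done():
                results.append(tasks.pop(0).result())
        while tasks:  # drain the tail
            loop.run_until_complete(tasks[0])
            results.append(tasks.pop(0).result())
    finally:
        loop.close()
    return results


print(run_throttled(10))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Popping finished tasks only from the head keeps outputs in input order, which is also why the real code forces `tasks[0]` to complete once the backlog exceeds ten times the cap.
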
@@ -1257,7 +1266,7 @@ def iter_outputs():
                     task.cancel(msg="KeyboardInterrupt")
                 try:
                     loop.run_until_complete(asyncio.gather(*tasks))
-                except asyncio.CancelledError:
+                except (asyncio.CancelledError, ValueError):
                     logger.debug("Tasks canceled.")
             raise

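
Widening the `except` clause suggests that gathering canceled tasks can also surface a `ValueError` (for example, from a loop left in an inconsistent state); that reading is an assumption, not stated in the diff. A hedged sketch of the cancel-then-gather shutdown sequence, with a hypothetical `never_ends` coroutine:

```py
import asyncio


async def never_ends() -> None:
    await asyncio.sleep(3600)


loop = asyncio.new_event_loop()
tasks = [loop.create_task(never_ends()) for _ in range(3)]
# request cancellation of every in-flight task, then run the loop
# so the cancellations actually propagate
for task in tasks:
    task.cancel(msg="KeyboardInterrupt")
try:
    loop.run_until_complete(asyncio.gather(*tasks))
except (asyncio.CancelledError, ValueError):
    print("Tasks canceled.")
finally:
    loop.close()
```
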
@@ -1347,6 +1356,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample
13471356 fn_kwargs = self .fn_kwargs ,
13481357 formatting = self .formatting ,
13491358 features = self .features ,
1359+ max_num_running_async_map_functions_in_parallel = self .max_num_running_async_map_functions_in_parallel ,
13501360 )
13511361
13521362 def shard_data_sources (self , num_shards : int , index : int , contiguous = True ) -> "MappedExamplesIterable" :
@@ -1363,6 +1373,7 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "M
13631373 fn_kwargs = self .fn_kwargs ,
13641374 formatting = self .formatting ,
13651375 features = self .features ,
1376+ max_num_running_async_map_functions_in_parallel = self .max_num_running_async_map_functions_in_parallel ,
13661377 )
13671378
13681379 @property
@@ -3189,6 +3200,99 @@ def cast(
             token_per_repo_id=self._token_per_repo_id,
         )

+    def decode(self, enable: bool = True, num_threads: int = 0) -> "IterableDataset":
+        """
+        Enable or disable the dataset features decoding for audio, image, video.
+
+        When enabled (default), media types are decoded:
+
+        * audio -> dict of "array" and "sampling_rate" and "path"
+        * image -> PIL.Image
+        * video -> torchvision.io.VideoReader
+
+        You can enable multithreading using `num_threads`. This is especially useful to speed up remote
+        data streaming. However, it can be slower than `num_threads=0` for local data on fast disks.
+
+        Disabling decoding is useful if you want to iterate on the paths or bytes of the media files
+        without actually decoding their content. To disable decoding, you can use `.decode(False)`, which
+        is equivalent to calling `.cast()` or `.cast_column()` with all the Audio, Image and Video types
+        set to `decode=False`.
+
+        Args:
+            enable (`bool`, defaults to `True`):
+                Enable or disable features decoding.
+            num_threads (`int`, defaults to `0`):
+                Enable multithreading for features decoding.
+
+        Returns:
+            `IterableDataset`: A copy of the dataset with cast features.
+
+        Examples:
+
+        Disable decoding:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("sshh12/planet-textures", split="train", streaming=True)
+        >>> next(iter(ds))
+        {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024>,
+         'text': 'A distant celestial object with an icy crust, displaying a light blue shade, covered with round pits and rugged terrains.'}
+        >>> ds = ds.decode(False)
+        >>> ds.features
+        {'image': Image(mode=None, decode=False, id=None),
+         'text': Value(dtype='string', id=None)}
+        >>> next(iter(ds))
+        {
+            'image': {
+                'path': 'hf://datasets/sshh12/planet-textures@69dc4cef7a5c4b2cfe387727ec8ea73d4bff7302/train/textures/0000.png',
+                'bytes': None
+            },
+            'text': 'A distant celestial object with an icy crust, displaying a light blue shade, covered with round pits and rugged terrains.'
+        }
+        ```
+
+        Speed up streaming with multithreading:
+
+        ```py
+        >>> import os
+        >>> from datasets import load_dataset
+        >>> from tqdm import tqdm
+        >>> ds = load_dataset("sshh12/planet-textures", split="train", streaming=True)
+        >>> num_threads = min(32, (os.cpu_count() or 1) + 4)
+        >>> ds = ds.decode(num_threads=num_threads)
+        >>> for _ in tqdm(ds):  # 20 times faster!
+        ...     ...
+        ```
+        """
+        if not self.features:
+            raise ValueError(
+                "Features decoding is only available for datasets with known features, but features are unknown. "
+                "Please set the dataset features with `ds = ds.cast(features)`."
+            )
+        ds = self
+
+        def set_decoding(decode: bool, feature):
+            if hasattr(feature, "decode"):
+                feature.decode = decode
+
+        if enable and num_threads > 0:
+            disabled_decoding_features = self.features.copy()
+            enabled_decoding_features = self.features.copy()
+
+            _visit(disabled_decoding_features, partial(set_decoding, False))
+            _visit(enabled_decoding_features, partial(set_decoding, True))
+            ds = ds.cast(disabled_decoding_features)
+            pool = multiprocessing.pool.ThreadPool(num_threads)
+            func = partial(_apply_async, pool, enabled_decoding_features.decode_example)
+            ds = ds.map(func, features=enabled_decoding_features)
+            assert isinstance(ds._ex_iterable, MappedExamplesIterable)
+            ds._ex_iterable.max_num_running_async_map_functions_in_parallel = 2 * num_threads
+        else:
+            features = ds.features.copy()
+            _visit(features, partial(set_decoding, enable))
+            ds = ds.cast(features)
+        return ds
+
     def _step(self, step: int, offset: int) -> "IterableDataset":
         ex_iterable = StepExamplesIterable(self._ex_iterable, step=step, offset=offset)
         return IterableDataset(
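
The docstring claims `.decode(False)` is equivalent to casting the media types with `decode=False`. A quick sketch of that equivalence, reusing the dataset from the docstring example (streaming, so network access is assumed):

```py
from datasets import Image, load_dataset

ds = load_dataset("sshh12/planet-textures", split="train", streaming=True)

ds_a = ds.decode(False)
ds_b = ds.cast_column("image", Image(decode=False))

# both now yield {'path': ..., 'bytes': ...} dicts for the image column
assert ds_a.features == ds_b.features
```
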
@@ -3407,3 +3511,12 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset:
         distributed=distributed,
         token_per_repo_id=dataset._token_per_repo_id,
     )
+
+
+async def _apply_async(pool, func, x):
+    future = pool.apply_async(func, (x,))
+    while True:
+        if future.ready():
+            return future.get()
+        else:
+            await asyncio.sleep(0)
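
`_apply_async` bridges the `ThreadPool` and the event loop: it submits the blocking function to a worker thread, then polls `future.ready()`, yielding control with `await asyncio.sleep(0)` between checks so other decode tasks keep progressing. A self-contained usage sketch; `slow_square` and the pool size are illustrative assumptions:

```py
import asyncio
import multiprocessing.pool
import time


# copied from the hunk above so the sketch runs standalone
async def _apply_async(pool, func, x):
    future = pool.apply_async(func, (x,))
    while True:
        if future.ready():
            return future.get()
        else:
            await asyncio.sleep(0)


def slow_square(x: int) -> int:
    time.sleep(0.1)  # stand-in for a blocking decode call
    return x * x


async def main() -> None:
    pool = multiprocessing.pool.ThreadPool(4)
    try:
        # 8 blocking calls on 4 threads finish in ~2 waves (~0.2s), not ~0.8s
        results = await asyncio.gather(*(_apply_async(pool, slow_square, i) for i in range(8)))
        print(results)
    finally:
        pool.close()
        pool.join()


asyncio.run(main())
```

The `sleep(0)` poll is a cooperative busy-wait: it trades a little CPU for never blocking the loop, which seems acceptable here since the loop is saturated with similar decode tasks anyway.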