From 7a62630499849a3fec8ec10d15786a9009d7430e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 3 Feb 2025 19:16:54 +0100 Subject: [PATCH 01/10] support async functions in map() --- src/datasets/arrow_dataset.py | 98 ++++++++++++++++++++++++++++++++++- src/datasets/config.py | 3 ++ 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 927bc01709f..4c146bc7c02 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -15,9 +15,11 @@ # Lint as: python3 """Simple Dataset wrapping an Arrow Table.""" +import asyncio import contextlib import copy import fnmatch +import inspect import itertools import json import math @@ -3383,6 +3385,73 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example else: return processed_inputs + async def async_apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0): + """Utility to apply the function on a selection of columns. 
Same code but async""" + nonlocal update_data + inputs = format_table( + pa_inputs, + 0 if not batched else range(pa_inputs.num_rows), + format_columns=input_columns, + formatter=input_formatter, + ) + fn_args = [inputs] if input_columns is None else [inputs[col] for col in input_columns] + if offset == 0: + effective_indices = indices + else: + effective_indices = [i + offset for i in indices] if isinstance(indices, list) else indices + offset + additional_args = () + if with_indices: + additional_args += (effective_indices,) + if with_rank: + additional_args += (rank,) + processed_inputs = await function(*fn_args, *additional_args, **fn_kwargs) + if isinstance(processed_inputs, LazyDict): + processed_inputs = { + k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format + } + returned_lazy_dict = True + else: + returned_lazy_dict = False + if update_data is None: + # Check if the function returns updated examples + updatable_types = (Mapping, pa.Table, pd.DataFrame) + if config.POLARS_AVAILABLE and "polars" in sys.modules: + import polars as pl + + updatable_types += (pl.DataFrame,) + update_data = isinstance(processed_inputs, updatable_types) + validate_function_output(processed_inputs, indices) + if not update_data: + return None # Nothing to update, let's move on + if shard._format_type or input_columns: + # TODO(QL, MS): ideally the behavior should be the same even if the dataset is formatted (may require major release) + inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns())) + elif isinstance(inputs, LazyDict): + inputs_to_merge = { + k: (v if k not in inputs.keys_to_format else pa_inputs[k]) for k, v in inputs.data.items() + } + else: + inputs_to_merge = inputs + if remove_columns is not None: + for column in remove_columns: + # `function` can modify input in-place causing column to be already removed. 
+ if column in inputs_to_merge: + inputs_to_merge.pop(column) + if returned_lazy_dict and column in processed_inputs: + processed_inputs.pop(column) + if check_same_num_examples: + input_num_examples = len(pa_inputs) + processed_inputs_num_examples = len(processed_inputs[next(iter(processed_inputs.keys()))]) + if input_num_examples != processed_inputs_num_examples: + raise NumExamplesMismatchError() + if isinstance(inputs, Mapping) and isinstance(processed_inputs, Mapping): + # The .map() transform *updates* the dataset: + # the output dictionary contains both the the input data and the output data. + # The output dictionary may contain Arrow values from `inputs_to_merge` so that we can re-write them efficiently. + return {**inputs_to_merge, **processed_inputs} + else: + return processed_inputs + def init_buffer_and_writer(): # Prepare output buffer and batched writer in memory or on file if we update the table writer_features = features @@ -3418,6 +3487,32 @@ def init_buffer_and_writer(): ) return buf_writer, writer, tmp_file + def iter_output_examples(shard_iterable): + if inspect.iscoroutinefunction(function): + indices: List[int] = [] + tasks: List[asyncio.Task] = [] + loop = asyncio.get_event_loop() + for i, example in shard_iterable: + indices.append(i) + tasks.append(loop.create_task(async_apply_function_on_filtered_inputs(example, i, offset=offset))) + # keep the total active tasks under 30 + if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL: + done, pending = loop.run_until_complete( + asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + ) + while tasks and len(pending) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL: + done, pending = loop.run_until_complete( + asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + ) + # yield finished tasks + while tasks and tasks[0].done(): + yield indices.pop(0), tasks.pop(0).result() + while tasks: + yield indices.pop(0), loop.run_until_complete(tasks.pop(0)) + else: + for i, 
example in shard_iterable: + yield i, apply_function_on_filtered_inputs(example, i, offset=offset) + num_examples_progress_update = 0 # If `update_data` is True after processing the first example/batch, initalize these resources with `init_buffer_and_writer` buf_writer, writer, tmp_file = None, None, None @@ -3442,8 +3537,7 @@ def init_buffer_and_writer(): ) if not batched: _time = time.time() - for i, example in shard_iterable: - example = apply_function_on_filtered_inputs(example, i, offset=offset) + for i, example in iter_output_examples(shard_iterable): if update_data: if i == 0: buf_writer, writer, tmp_file = init_buffer_and_writer() diff --git a/src/datasets/config.py b/src/datasets/config.py index 43801efcaef..47ec8868f17 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -255,6 +255,9 @@ GLOBBED_DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 10 ARCHIVED_DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 200 +# Async map functions +MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL = 1000 + # Progress bars PBAR_REFRESH_TIME_INTERVAL = 0.05 # 20 progress updates per sec From 1f6a69ab55709bbaf75bb42f2154bb1e0440aef3 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 11 Feb 2025 19:12:20 +0100 Subject: [PATCH 02/10] simplify code --- src/datasets/arrow_dataset.py | 103 ++++++++-------------------------- 1 file changed, 23 insertions(+), 80 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 4c146bc7c02..bacc5e662a3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3281,7 +3281,7 @@ def _map_single( class NumExamplesMismatchError(Exception): pass - def validate_function_output(processed_inputs, indices): + def validate_function_output(processed_inputs): """Validate output of the map function.""" allowed_processed_inputs_types = (Mapping, pa.Table, pd.DataFrame) if config.POLARS_AVAILABLE and "polars" in sys.modules: @@ -3292,7 +3292,7 @@ def 
validate_function_output(processed_inputs, indices): raise TypeError( f"Provided `function` which is applied to all elements of table returns a variable of type {type(processed_inputs)}. Make sure provided `function` returns a variable of type `dict` (or a pyarrow table) to update the dataset or `None` if you are only interested in side effects." ) - elif isinstance(indices, list) and isinstance(processed_inputs, Mapping): + if batched and isinstance(processed_inputs, Mapping): allowed_batch_return_types = (list, np.ndarray, pd.Series) if config.POLARS_AVAILABLE and "polars" in sys.modules: import polars as pl @@ -3318,9 +3318,8 @@ def validate_function_output(processed_inputs, indices): f"Provided `function` which is applied to all elements of table returns a `dict` of types {[type(x) for x in processed_inputs.values()]}. When using `batched=True`, make sure provided `function` returns a `dict` of types like `{allowed_batch_return_types}`." ) - def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0): + def prepare_inputs(pa_inputs, indices, offset=0): """Utility to apply the function on a selection of columns.""" - nonlocal update_data inputs = format_table( pa_inputs, 0 if not batched else range(pa_inputs.num_rows), @@ -3337,7 +3336,12 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example additional_args += (effective_indices,) if with_rank: additional_args += (rank,) - processed_inputs = function(*fn_args, *additional_args, **fn_kwargs) + return inputs, fn_args, additional_args, fn_kwargs + + def prepare_outputs(pa_inputs, inputs, processed_inputs, check_same_num_examples=False): + nonlocal update_data + if not (update_data := (processed_inputs is not None)): + return None if isinstance(processed_inputs, LazyDict): processed_inputs = { k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format @@ -3345,17 +3349,7 @@ def 
apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example returned_lazy_dict = True else: returned_lazy_dict = False - if update_data is None: - # Check if the function returns updated examples - updatable_types = (Mapping, pa.Table, pd.DataFrame) - if config.POLARS_AVAILABLE and "polars" in sys.modules: - import polars as pl - - updatable_types += (pl.DataFrame,) - update_data = isinstance(processed_inputs, updatable_types) - validate_function_output(processed_inputs, indices) - if not update_data: - return None # Nothing to update, let's move on + validate_function_output(processed_inputs) if shard._format_type or input_columns: # TODO(QL, MS): ideally the behavior should be the same even if the dataset is formatted (may require major release) inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns())) @@ -3385,72 +3379,21 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example else: return processed_inputs + def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0): + """Utility to apply the function on a selection of columns.""" + inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(pa_inputs, indices, offset=offset) + processed_inputs = function(*fn_args, *additional_args, **fn_kwargs) + return prepare_outputs( + pa_inputs, inputs, processed_inputs, check_same_num_examples=check_same_num_examples + ) + async def async_apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0): """Utility to apply the function on a selection of columns. 
Same code but async""" - nonlocal update_data - inputs = format_table( - pa_inputs, - 0 if not batched else range(pa_inputs.num_rows), - format_columns=input_columns, - formatter=input_formatter, - ) - fn_args = [inputs] if input_columns is None else [inputs[col] for col in input_columns] - if offset == 0: - effective_indices = indices - else: - effective_indices = [i + offset for i in indices] if isinstance(indices, list) else indices + offset - additional_args = () - if with_indices: - additional_args += (effective_indices,) - if with_rank: - additional_args += (rank,) + inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(pa_inputs, indices, offset=offset) processed_inputs = await function(*fn_args, *additional_args, **fn_kwargs) - if isinstance(processed_inputs, LazyDict): - processed_inputs = { - k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format - } - returned_lazy_dict = True - else: - returned_lazy_dict = False - if update_data is None: - # Check if the function returns updated examples - updatable_types = (Mapping, pa.Table, pd.DataFrame) - if config.POLARS_AVAILABLE and "polars" in sys.modules: - import polars as pl - - updatable_types += (pl.DataFrame,) - update_data = isinstance(processed_inputs, updatable_types) - validate_function_output(processed_inputs, indices) - if not update_data: - return None # Nothing to update, let's move on - if shard._format_type or input_columns: - # TODO(QL, MS): ideally the behavior should be the same even if the dataset is formatted (may require major release) - inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns())) - elif isinstance(inputs, LazyDict): - inputs_to_merge = { - k: (v if k not in inputs.keys_to_format else pa_inputs[k]) for k, v in inputs.data.items() - } - else: - inputs_to_merge = inputs - if remove_columns is not None: - for column in remove_columns: - # `function` can modify input in-place causing column to be already removed. 
- if column in inputs_to_merge: - inputs_to_merge.pop(column) - if returned_lazy_dict and column in processed_inputs: - processed_inputs.pop(column) - if check_same_num_examples: - input_num_examples = len(pa_inputs) - processed_inputs_num_examples = len(processed_inputs[next(iter(processed_inputs.keys()))]) - if input_num_examples != processed_inputs_num_examples: - raise NumExamplesMismatchError() - if isinstance(inputs, Mapping) and isinstance(processed_inputs, Mapping): - # The .map() transform *updates* the dataset: - # the output dictionary contains both the the input data and the output data. - # The output dictionary may contain Arrow values from `inputs_to_merge` so that we can re-write them efficiently. - return {**inputs_to_merge, **processed_inputs} - else: - return processed_inputs + return prepare_outputs( + pa_inputs, inputs, processed_inputs, check_same_num_examples=check_same_num_examples + ) def init_buffer_and_writer(): # Prepare output buffer and batched writer in memory or on file if we update the table @@ -3495,7 +3438,7 @@ def iter_output_examples(shard_iterable): for i, example in shard_iterable: indices.append(i) tasks.append(loop.create_task(async_apply_function_on_filtered_inputs(example, i, offset=offset))) - # keep the total active tasks under 30 + # keep the total active tasks under a certain number if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL: done, pending = loop.run_until_complete( asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) From d9ef602739c1d8f2a1fd3d25da865e89ec12607b Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 12 Feb 2025 11:03:36 +0100 Subject: [PATCH 03/10] batched async --- src/datasets/arrow_dataset.py | 49 ++++++++++++----------------------- 1 file changed, 16 insertions(+), 33 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index bacc5e662a3..659c4858858 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py 
@@ -3278,8 +3278,7 @@ def _map_single( **format_kwargs, ) - class NumExamplesMismatchError(Exception): - pass + check_same_num_examples = batched and len(shard.list_indexes()) > 0 def validate_function_output(processed_inputs): """Validate output of the map function.""" @@ -3338,7 +3337,7 @@ def prepare_inputs(pa_inputs, indices, offset=0): additional_args += (rank,) return inputs, fn_args, additional_args, fn_kwargs - def prepare_outputs(pa_inputs, inputs, processed_inputs, check_same_num_examples=False): + def prepare_outputs(pa_inputs, inputs, processed_inputs): nonlocal update_data if not (update_data := (processed_inputs is not None)): return None @@ -3370,7 +3369,9 @@ def prepare_outputs(pa_inputs, inputs, processed_inputs, check_same_num_examples input_num_examples = len(pa_inputs) processed_inputs_num_examples = len(processed_inputs[next(iter(processed_inputs.keys()))]) if input_num_examples != processed_inputs_num_examples: - raise NumExamplesMismatchError() + raise DatasetTransformationNotAllowedError( + "Using `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn't create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it." + ) from None if isinstance(inputs, Mapping) and isinstance(processed_inputs, Mapping): # The .map() transform *updates* the dataset: # the output dictionary contains both the the input data and the output data. 
@@ -3379,21 +3380,17 @@ def prepare_outputs(pa_inputs, inputs, processed_inputs, check_same_num_examples else: return processed_inputs - def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0): + def apply_function(pa_inputs, indices, offset=0): """Utility to apply the function on a selection of columns.""" inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(pa_inputs, indices, offset=offset) processed_inputs = function(*fn_args, *additional_args, **fn_kwargs) - return prepare_outputs( - pa_inputs, inputs, processed_inputs, check_same_num_examples=check_same_num_examples - ) + return prepare_outputs(pa_inputs, inputs, processed_inputs) - async def async_apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0): + async def async_apply_function(pa_inputs, indices, offset=0): """Utility to apply the function on a selection of columns. Same code but async""" inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(pa_inputs, indices, offset=offset) processed_inputs = await function(*fn_args, *additional_args, **fn_kwargs) - return prepare_outputs( - pa_inputs, inputs, processed_inputs, check_same_num_examples=check_same_num_examples - ) + return prepare_outputs(pa_inputs, inputs, processed_inputs) def init_buffer_and_writer(): # Prepare output buffer and batched writer in memory or on file if we update the table @@ -3432,12 +3429,12 @@ def init_buffer_and_writer(): def iter_output_examples(shard_iterable): if inspect.iscoroutinefunction(function): - indices: List[int] = [] + indices: Union[List[int], List[List[int]]] = [] tasks: List[asyncio.Task] = [] loop = asyncio.get_event_loop() for i, example in shard_iterable: indices.append(i) - tasks.append(loop.create_task(async_apply_function_on_filtered_inputs(example, i, offset=offset))) + tasks.append(loop.create_task(async_apply_function(example, i, offset=offset))) # keep the total active tasks under a certain number if 
len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL: done, pending = loop.run_until_complete( @@ -3454,7 +3451,7 @@ def iter_output_examples(shard_iterable): yield indices.pop(0), loop.run_until_complete(tasks.pop(0)) else: for i, example in shard_iterable: - yield i, apply_function_on_filtered_inputs(example, i, offset=offset) + yield i, apply_function(example, i, offset=offset) num_examples_progress_update = 0 # If `update_data` is True after processing the first example/batch, initalize these resources with `init_buffer_and_writer` @@ -3475,7 +3472,7 @@ def iter_output_examples(shard_iterable): else: num_rows = len(shard) if not drop_last_batch else len(shard) // batch_size * batch_size shard_iterable = zip( - range(0, num_rows, batch_size), + (list(range(i, min(i + batch_size, num_rows))) for i in range(0, num_rows, batch_size)), arrow_formatted_shard.iter(batch_size, drop_last_batch=drop_last_batch), ) if not batched: @@ -3504,24 +3501,10 @@ def iter_output_examples(shard_iterable): num_examples_progress_update = 0 else: _time = time.time() - for i, batch in shard_iterable: - num_examples_in_batch = len(batch) - indices = list( - range(*(slice(i, i + batch_size).indices(shard.num_rows))) - ) # Something simpler? - try: - batch = apply_function_on_filtered_inputs( - batch, - indices, - check_same_num_examples=len(shard.list_indexes()) > 0, - offset=offset, - ) - except NumExamplesMismatchError: - raise DatasetTransformationNotAllowedError( - "Using `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn't create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it." 
- ) from None + for i, batch in iter_output_examples(shard_iterable): + num_examples_in_batch = len(i) if update_data: - if i == 0: + if i and i[0] == 0: buf_writer, writer, tmp_file = init_buffer_and_writer() stack.enter_context(writer) if isinstance(batch, pa.Table): From 491875bcb7991ee64d122122bbfd1adf32208f24 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 12 Feb 2025 12:45:03 +0100 Subject: [PATCH 04/10] async map in iterable dataset --- src/datasets/arrow_dataset.py | 6 +- src/datasets/iterable_dataset.py | 170 +++++++++++++++++++------------ 2 files changed, 109 insertions(+), 67 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 659c4858858..133d4e75d5c 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3427,7 +3427,7 @@ def init_buffer_and_writer(): ) return buf_writer, writer, tmp_file - def iter_output_examples(shard_iterable): + def iter_outputs(shard_iterable): if inspect.iscoroutinefunction(function): indices: Union[List[int], List[List[int]]] = [] tasks: List[asyncio.Task] = [] @@ -3477,7 +3477,7 @@ def iter_output_examples(shard_iterable): ) if not batched: _time = time.time() - for i, example in iter_output_examples(shard_iterable): + for i, example in iter_outputs(shard_iterable): if update_data: if i == 0: buf_writer, writer, tmp_file = init_buffer_and_writer() @@ -3501,7 +3501,7 @@ def iter_output_examples(shard_iterable): num_examples_progress_update = 0 else: _time = time.time() - for i, batch in iter_output_examples(shard_iterable): + for i, batch in iter_outputs(shard_iterable): num_examples_in_batch = len(i) if update_data: if i and i[0] == 0: diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 317cc0b1723..f2b154408ea 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1,4 +1,6 @@ +import asyncio import copy +import inspect import itertools import sys from collections import Counter @@ 
-1075,11 +1077,7 @@ def _iter(self): else: format_dict = None - if self.batched: - if self._state_dict: - self._state_dict["previous_state"] = self.ex_iterable.state_dict() - self._state_dict["num_examples_since_previous_state"] = 0 - self._state_dict["previous_state_example_idx"] = current_idx + def iter_batched_inputs(): for key, example in iterator: # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset iterator_batch = ( @@ -1089,6 +1087,8 @@ def _iter(self): ) key_examples_list = [(key, example)] + list(iterator_batch) keys, examples = zip(*key_examples_list) + # the new key is the concatenation of the examples keys from the batch + key = "_".join(str(key) for key in keys) if ( self.drop_last_batch and self.batch_size is not None @@ -1098,40 +1098,106 @@ def _iter(self): return batch = _examples_to_batch(examples) batch = format_dict(batch) if format_dict else batch - # then apply the transform - inputs = batch - function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns] - if self.with_indices: - function_args.append([current_idx + i for i in range(len(key_examples_list))]) - inputs_to_merge = dict(batch) - processed_inputs = self.function(*function_args, **self.fn_kwargs) - # this logic mimics the one in Dataset.map - if self.remove_columns: - for c in self.remove_columns: - if c in inputs_to_merge: - del inputs_to_merge[c] - if processed_inputs is inputs and c in processed_inputs: - del processed_inputs[c] - transformed_batch = {**inputs_to_merge, **processed_inputs} - if transformed_batch: - first_col = next(iter(transformed_batch)) - bad_cols = [ - col - for col in transformed_batch - if len(transformed_batch[col]) != len(transformed_batch[first_col]) - ] - if bad_cols: - raise ValueError( - f"Column lengths mismatch: columns {bad_cols} have length {[len(transformed_batch[col]) for col in bad_cols]} " - f"while {first_col} has length 
{len(transformed_batch[first_col])}." + indices = [current_idx + i for i in range(len(key_examples_list))] + yield indices, (key, batch) + + def iter_inputs(): + for key, example in iterator: + # If not batched, we can apply the transform and yield the example directly + # first copy the example, since we might drop some keys + example = dict(example) + example = format_dict(example) if format_dict else example + yield current_idx, (key, example) + + def validate_function_output(processed_inputs): + if self.batched and processed_inputs: + first_col = next(iter(processed_inputs)) + bad_cols = [ + col for col in processed_inputs if len(processed_inputs[col]) != len(processed_inputs[first_col]) + ] + if bad_cols: + raise ValueError( + f"Column lengths mismatch: columns {bad_cols} have length {[len(processed_inputs[col]) for col in bad_cols]} " + f"while {first_col} has length {len(processed_inputs[first_col])}." + ) + + def prepare_inputs(key_example, indices): + key, example = key_example + fn_args = [example] if self.input_columns is None else [example[col] for col in self.input_columns] + additional_args = () + if self.with_indices: + fn_args += (indices,) + inputs_to_merge = dict(example) + return inputs_to_merge, fn_args, additional_args, self.fn_kwargs + + def prepare_outputs(inputs, processed_inputs): + validate_function_output(processed_inputs) + # this logic mimics the one in Dataset.map + if self.remove_columns: + for c in self.remove_columns: + if c in inputs: + del inputs[c] + if processed_inputs is inputs and c in processed_inputs: + del processed_inputs[c] + transformed_inputs = {**inputs, **processed_inputs} + if self.features: + for c in self.features.keys(): + if c not in transformed_inputs: + transformed_inputs[c] = ( + [None] * len(transformed_inputs[next(iter(processed_inputs))]) if self.batched else None ) - if self.features: - for c in self.features.keys(): - if c not in transformed_batch: - transformed_batch[c] = [None] * 
len(transformed_batch[first_col]) - transformed_batch = self.features.decode_batch(transformed_batch) - # the new key is the concatenation of the examples keys from the batch - new_key = "_".join(str(key) for key in keys) + transformed_inputs = ( + self.features.decode_batch(transformed_inputs) + if self.batched + else self.features.decode_example(transformed_inputs) + ) + return transformed_inputs + + def apply_function(key_example, indices): + """Utility to apply the function on a selection of columns.""" + inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices) + processed_inputs = self.function(*fn_args, *additional_args, **fn_kwargs) + return prepare_outputs(inputs, processed_inputs) + + async def async_apply_function(key_example, indices): + """Utility to apply the function on a selection of columns. Same code but async""" + inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices) + processed_inputs = await self.function(*fn_args, *additional_args, **fn_kwargs) + return prepare_outputs(inputs, processed_inputs) + + def iter_outputs(): + inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs() + if inspect.iscoroutinefunction(self.function): + indices: Union[List[int], List[List[int]]] = [] + tasks: List[asyncio.Task] = [] + loop = asyncio.get_event_loop() + for i, key_example in inputs_iterator: + indices.append(i) + tasks.append(loop.create_task(async_apply_function(key_example, i))) + # keep the total active tasks under a certain number + if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL: + done, pending = loop.run_until_complete( + asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + ) + while tasks and len(pending) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL: + done, pending = loop.run_until_complete( + asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + ) + # yield finished tasks + while tasks and tasks[0].done(): + yield 
indices.pop(0), tasks.pop(0).result() + while tasks: + yield indices.pop(0), loop.run_until_complete(tasks.pop(0)) + else: + for i, key_example in inputs_iterator: + yield i, apply_function(key_example, i) + + if self.batched: + if self._state_dict: + self._state_dict["previous_state"] = self.ex_iterable.state_dict() + self._state_dict["num_examples_since_previous_state"] = 0 + self._state_dict["previous_state_example_idx"] = current_idx + for key, transformed_batch in iter_outputs(): # yield one example at a time from the transformed batch for example in _batch_to_examples(transformed_batch): current_idx += 1 @@ -1140,37 +1206,13 @@ def _iter(self): if num_examples_to_skip > 0: num_examples_to_skip -= 1 continue - yield new_key, example + yield key, example if self._state_dict: self._state_dict["previous_state"] = self.ex_iterable.state_dict() self._state_dict["num_examples_since_previous_state"] = 0 self._state_dict["previous_state_example_idx"] = current_idx else: - for key, example in iterator: - # If not batched, we can apply the transform and yield the example directly - # first copy the example, since we might drop some keys - example = dict(example) - example = format_dict(example) if format_dict else example - # then apply the transform - inputs = example - function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns] - if self.with_indices: - function_args.append(current_idx) - processed_inputs = self.function(*function_args, **self.fn_kwargs) - inputs_to_merge = dict(example) - # this logic mimics the one in Dataset.map - if self.remove_columns: - for c in self.remove_columns: - if c in inputs_to_merge: - del inputs_to_merge[c] - if processed_inputs is inputs and c in processed_inputs: - del processed_inputs[c] - transformed_example = {**inputs_to_merge, **processed_inputs} - if self.features: - for c in self.features.keys(): - if c not in transformed_example: - transformed_example[c] = None - transformed_example 
= self.features.decode_example(transformed_example) + for key, transformed_example in iter_outputs(): current_idx += 1 if self._state_dict: self._state_dict["previous_state_example_idx"] += 1 From 893ca6b4b142007debae67d4ba89262ee6394a8c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 12 Feb 2025 15:14:45 +0100 Subject: [PATCH 05/10] async filter --- src/datasets/arrow_dataset.py | 67 +++++++- src/datasets/iterable_dataset.py | 259 ++++++++----------------------- 2 files changed, 131 insertions(+), 195 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 133d4e75d5c..ee97c37b4cc 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3707,7 +3707,9 @@ def filter( indices = self.map( function=partial( - get_indices_from_mask_function, + async_get_indices_from_mask_function + if inspect.iscoroutinefunction(function) + else get_indices_from_mask_function, function, batched, with_indices, @@ -3719,7 +3721,7 @@ def filter( with_rank=True, features=Features({"indices": Value("uint64")}), batched=True, - batch_size=batch_size, + batch_size=batch_size if batched else 1, remove_columns=self.column_names, keep_in_memory=keep_in_memory, load_from_cache_file=load_from_cache_file, @@ -6337,7 +6339,7 @@ def get_indices_from_mask_function( if isinstance(mask, (pa.Array, pa.ChunkedArray)): mask = mask.to_pylist() else: - # we get batched data (to do less look-ups) but `function` only accepts one example + # we get batched data (to return less data than input) but `function` only accepts one example # therefore we need to call `function` on each example of the batch to get the mask *inputs, indices, rank = args mask = [] @@ -6371,3 +6373,62 @@ def get_indices_from_mask_function( indices_array = indices_mapping.column(0).take(indices_array) indices_array = indices_array.to_pylist() return {"indices": indices_array} + + +async def async_get_indices_from_mask_function( + function: Callable, + batched: bool, + 
with_indices: bool, + with_rank: bool, + input_columns: Optional[Union[str, List[str]]], + indices_mapping: Optional[Table] = None, + *args, + **fn_kwargs, +): + """same function but async""" + if batched: + # we extract indices and rank from args + *inputs, indices, rank = args + additional_args = () + if with_indices: + additional_args += (indices,) + if with_rank: + additional_args += (rank,) + mask = await function(*inputs, *additional_args, **fn_kwargs) + if isinstance(mask, (pa.Array, pa.ChunkedArray)): + mask = mask.to_pylist() + else: + # we get batched data (to return less data than input) but `function` only accepts one example + # therefore we need to call `function` on each example of the batch to get the mask + *inputs, indices, rank = args + mask = [] + if input_columns is None: + # inputs only contains a batch of examples + batch: dict = inputs[0] + num_examples = len(batch[next(iter(batch.keys()))]) + for i in range(num_examples): + example = {key: batch[key][i] for key in batch} + additional_args = () + if with_indices: + additional_args += (indices[i],) + if with_rank: + additional_args += (rank,) + mask.append(await function(example, *additional_args, **fn_kwargs)) + else: + # inputs is a list of columns + columns: List[List] = inputs + num_examples = len(columns[0]) + for i in range(num_examples): + input = [column[i] for column in columns] + additional_args = () + if with_indices: + additional_args += (indices[i],) + if with_rank: + additional_args += (rank,) + mask.append(await function(*input, *additional_args, **fn_kwargs)) + indices_array = [i for i, to_keep in zip(indices, mask) if to_keep] + if indices_mapping is not None: + indices_array = pa.array(indices_array, type=pa.uint64()) + indices_array = indices_mapping.column(0).take(indices_array) + indices_array = indices_array.to_pylist() + return {"indices": indices_array} diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index f2b154408ea..2cada6bc863 
100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -18,7 +18,13 @@ from . import config from .arrow_dataset import Dataset, DatasetInfoMixin from .features import Features -from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects +from .features.features import ( + FeatureType, + Value, + _align_features, + _check_if_features_can_be_aligned, + cast_to_python_objects, +) from .formatting import ( ArrowFormatter, PythonFormatter, @@ -1021,12 +1027,12 @@ def __init__( # batch_size should match for iter_arrow if not isinstance(ex_iterable, RebatchedArrowExamplesIterable): raise ValueError( - f"The {formatting.format_type.capitalize()}-formatted MappedExamplesIterable has underlying iterable" + f"The {formatting.format_type.capitalize()}-formatted {type(self).__name__} has underlying iterable" f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable." ) elif ex_iterable.batch_size != (batch_size if batched else 1): raise ValueError( - f"The {formatting.format_type.capitalize()}-formatted MappedExamplesIterable has batch_size={batch_size if batched else 1} which is" + f"The {formatting.format_type.capitalize()}-formatted {type(self).__name__} has batch_size={batch_size if batched else 1} which is" f"different from {ex_iterable.batch_size=} from its underlying iterable." 
) @@ -1327,7 +1333,34 @@ def num_shards(self) -> int: return self.ex_iterable.num_shards -class FilteredExamplesIterable(_BaseExamplesIterable): +def _add_mask( + input: Union[dict, pa.Table], + mask: Union[bool, list, pa.Array, pa.ChunkedArray, pa.BooleanScalar], + mask_column_name: str, +): + if isinstance(input, pa.Table): + if not isinstance(mask, (list, pa.Array, pa.ChunkedArray)): + mask = [mask] + return input.add_column(mask_column_name, mask) + else: + return {mask_column_name: mask} + + +def add_mask(mask_function: Callable, input: Union[dict, pa.Table], *args, mask_column_name: str, **kwargs): + mask = mask_function(input, *args, **kwargs) + return _add_mask(input, mask, mask_column_name) + + +async def async_add_mask( + mask_function: Callable, input: Union[dict, pa.Table], *args, mask_column_name: str, **kwargs +): + mask = await mask_function(input, *args, **kwargs) + return _add_mask(input, mask, mask_column_name) + + +class FilteredExamplesIterable(MappedExamplesIterable): + mask_column_name = "===MASK===" + def __init__( self, ex_iterable: _BaseExamplesIterable, @@ -1339,207 +1372,48 @@ def __init__( fn_kwargs: Optional[dict] = None, formatting: Optional["FormattingConfig"] = None, ): - super().__init__() - self.ex_iterable = ex_iterable - self.function = function - self.batched = batched - self.batch_size = batch_size - self.with_indices = with_indices - self.input_columns = input_columns - self.fn_kwargs = fn_kwargs or {} - self.formatting = formatting # required for iter_arrow - # sanity checks - if formatting and formatting.is_table: - # batch_size should match for iter_arrow - if not isinstance(ex_iterable, RebatchedArrowExamplesIterable): - raise ValueError( - f"The {formatting.format_type.capitalize()}-formatted FilteredExamplesIterable has underlying iterable" - f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable." 
- ) - elif ex_iterable.batch_size != (batch_size if batched else 1): - raise ValueError( - f"The {formatting.format_type.capitalize()}-formatted FilteredExamplesIterable has batch_size={batch_size if batched else 1} which is" - f"different from {ex_iterable.batch_size=} from its underlying iterable." - ) - - @property - def iter_arrow(self): - if self.formatting and self.formatting.format_type == "arrow": - return self._iter_arrow - - @property - def is_typed(self): - return self.ex_iterable.is_typed - - @property - def features(self): - return self.ex_iterable.features - - def _init_state_dict(self) -> dict: - self._state_dict = { - "ex_iterable": self.ex_iterable._init_state_dict(), - "previous_state": None, - "num_examples_since_previous_state": 0, - "previous_state_example_idx": 0, - } - return self._state_dict - - def __iter__(self): - if self.formatting and self.formatting.format_type == "arrow": - formatter = PythonFormatter() - for key, pa_table in self._iter_arrow(max_chunksize=1): - yield key, formatter.format_row(pa_table) + self.mask_function = function + if ex_iterable.is_typed: + features = Features({**ex_iterable.features, self.mask_column_name: Value("bool")}) else: - yield from self._iter() + features = None + super().__init__( + ex_iterable=ex_iterable, + function=partial( + async_add_mask if inspect.iscoroutinefunction(function) else add_mask, + function, + mask_column_name=self.mask_column_name, + ), + with_indices=with_indices, + input_columns=input_columns, + batched=batched, + batch_size=batch_size, + fn_kwargs=fn_kwargs, + formatting=formatting, + features=features, + ) def _iter(self): - current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0 - if self._state_dict and self._state_dict["previous_state"]: - self.ex_iterable.load_state_dict(self._state_dict["previous_state"]) - num_examples_to_skip = self._state_dict["num_examples_since_previous_state"] - else: - num_examples_to_skip = 0 - iterator = 
iter(self.ex_iterable) - - if self.formatting: - formatter = get_formatter(self.formatting.format_type) - format_dict = ( - formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects - ) - else: - format_dict = None - - if self.batched: - if self._state_dict: - self._state_dict["previous_state"] = self.ex_iterable.state_dict() - self._state_dict["num_examples_since_previous_state"] = 0 - self._state_dict["previous_state_example_idx"] = current_idx - for key, example in iterator: - # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset - iterator_batch = ( - iterator - if self.batch_size is None or self.batch_size <= 0 - else islice(iterator, self.batch_size - 1) - ) - key_examples_list = [(key, example)] + list(iterator_batch) - keys, examples = zip(*key_examples_list) - batch = _examples_to_batch(examples) - batch = format_dict(batch) if format_dict else batch - # then compute the mask for the batch - inputs = batch - function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns] - if self.with_indices: - function_args.append([current_idx + i for i in range(len(key_examples_list))]) - mask = self.function(*function_args, **self.fn_kwargs) - # yield one example at a time from the batch - for key_example, to_keep in zip(key_examples_list, mask): - current_idx += 1 - if self._state_dict: - self._state_dict["num_examples_since_previous_state"] += 1 - if num_examples_to_skip > 0: - num_examples_to_skip -= 1 - continue - if to_keep: - yield key_example - if self._state_dict: - self._state_dict["previous_state"] = self.ex_iterable.state_dict() - self._state_dict["num_examples_since_previous_state"] = 0 - self._state_dict["previous_state_example_idx"] = current_idx - else: - for key, example in iterator: - # If not batched, we can apply the filtering function direcly - example = dict(example) - inputs = format_dict(example) if format_dict 
else example - function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns] - if self.with_indices: - function_args.append(current_idx) - to_keep = self.function(*function_args, **self.fn_kwargs) - current_idx += 1 - if self._state_dict: - self._state_dict["previous_state_example_idx"] += 1 - if to_keep: - yield key, example + for key, example in super()._iter(): + example = dict(example) + if example.pop(self.mask_column_name): + yield key, example def _iter_arrow(self, max_chunksize: Optional[int] = None): - formatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter() - if self.ex_iterable.iter_arrow: - iterator = self.ex_iterable.iter_arrow() - else: - iterator = _convert_to_arrow(self.ex_iterable, batch_size=self.batch_size if self.batched else 1) - - if self._state_dict and self._state_dict["previous_state"]: - self.ex_iterable.load_state_dict(self._state_dict["previous_state"]) - num_examples_to_skip = self._state_dict["num_examples_since_previous_state"] - else: - num_examples_to_skip = 0 - if self._state_dict and max_chunksize is not None: - self._state_dict["previous_state"] = self.ex_iterable.state_dict() - self._state_dict["num_examples_since_previous_state"] = 0 - current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0 - for key, pa_table in iterator: - if ( - self.batched - and self.batch_size is not None - and len(pa_table) < self.batch_size - and self.drop_last_batch - ): - return - - function_args = ( - [formatter.format_batch(pa_table)] - if self.input_columns is None - else [pa_table[col] for col in self.input_columns] - ) - if self.with_indices: - if self.batched: - function_args.append([current_idx + i for i in range(len(pa_table))]) - else: - function_args.append(current_idx) - # then apply the transform - output = self.function(*function_args, **self.fn_kwargs) - mask = _table_output_to_arrow(output) - if not isinstance(mask, (bool, 
pa.Array, pa.BooleanScalar)): - raise TypeError( - f"Provided `function` which is applied to {formatter.table_type} returns a variable of type " - f"{type(output)}. Make sure provided `function` returns a {formatter.column_type} to update the dataset." - ) - # return output - if self.batched: - output_table = pa_table.filter(mask) - elif mask.as_py() if isinstance(mask, pa.BooleanScalar) else mask: - output_table = pa_table - else: - output_table = pa_table.slice(0, 0) - - if max_chunksize is None: - current_idx += len(pa_table) - if self._state_dict: - self._state_dict["previous_state_example_idx"] += len(pa_table) - if len(output_table) > 0: - yield key, output_table - else: - for i, pa_subtable in enumerate(output_table.to_reader(max_chunksize=max_chunksize)): - current_idx += 1 - if self._state_dict: - self._state_dict["num_examples_since_previous_state"] += 1 - if num_examples_to_skip > 0: - num_examples_to_skip -= 1 - continue - yield f"{key}_{i}", pa_subtable - if self._state_dict: - self._state_dict["previous_state"] = self.ex_iterable.state_dict() - self._state_dict["num_examples_since_previous_state"] = 0 - self._state_dict["previous_state_example_idx"] += len(pa_table) + for key, pa_table in super()._iter_arrow(max_chunksize=max_chunksize): + mask = pa_table[self.mask_column_name] + yield key, pa_table.drop(self.mask_column_name).filter(mask) def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable": """Shuffle the wrapped examples iterable.""" return FilteredExamplesIterable( self.ex_iterable.shuffle_data_sources(seed), - function=self.function, + function=self.mask_function, with_indices=self.with_indices, input_columns=self.input_columns, batched=self.batched, batch_size=self.batch_size, + fn_kwargs=self.fn_kwargs, formatting=self.formatting, ) @@ -1547,11 +1421,12 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "F """Keep only the requested shard.""" return FilteredExamplesIterable( 
self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), - function=self.function, + function=self.mask_function, with_indices=self.with_indices, input_columns=self.input_columns, batched=self.batched, batch_size=self.batch_size, + fn_kwargs=self.fn_kwargs, formatting=self.formatting, ) From 23f6c9b121996f9568f49c76c21eb71ca2368a26 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 12 Feb 2025 15:30:38 +0100 Subject: [PATCH 06/10] fix ci --- src/datasets/iterable_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 2cada6bc863..edb189b972a 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1340,8 +1340,8 @@ def _add_mask( ): if isinstance(input, pa.Table): if not isinstance(mask, (list, pa.Array, pa.ChunkedArray)): - mask = [mask] - return input.add_column(mask_column_name, mask) + mask = pa.array([mask], type=pa.bool_()) + return input.append_column(mask_column_name, mask) else: return {mask_column_name: mask} From 95b274138d1f671cd77e95ebc48013e1f712cbf2 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 12 Feb 2025 15:52:57 +0100 Subject: [PATCH 07/10] fix ci --- src/datasets/iterable_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index edb189b972a..a95b276d89b 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1133,17 +1133,17 @@ def prepare_inputs(key_example, indices): additional_args = () if self.with_indices: fn_args += (indices,) - inputs_to_merge = dict(example) - return inputs_to_merge, fn_args, additional_args, self.fn_kwargs + inputs = dict(example) + return inputs, fn_args, additional_args, self.fn_kwargs - def prepare_outputs(inputs, processed_inputs): + def prepare_outputs(key_example, inputs, processed_inputs): 
validate_function_output(processed_inputs) # this logic mimics the one in Dataset.map if self.remove_columns: for c in self.remove_columns: if c in inputs: del inputs[c] - if processed_inputs is inputs and c in processed_inputs: + if processed_inputs is key_example[1] and c in processed_inputs: del processed_inputs[c] transformed_inputs = {**inputs, **processed_inputs} if self.features: @@ -1163,13 +1163,13 @@ def apply_function(key_example, indices): """Utility to apply the function on a selection of columns.""" inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices) processed_inputs = self.function(*fn_args, *additional_args, **fn_kwargs) - return prepare_outputs(inputs, processed_inputs) + return prepare_outputs(key_example, inputs, processed_inputs) async def async_apply_function(key_example, indices): """Utility to apply the function on a selection of columns. Same code but async""" inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices) processed_inputs = await self.function(*fn_args, *additional_args, **fn_kwargs) - return prepare_outputs(inputs, processed_inputs) + return prepare_outputs(key_example, inputs, processed_inputs) def iter_outputs(): inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs() From 5b67d2045445201de365a2116d632cd5cb8d8779 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 12 Feb 2025 15:58:45 +0100 Subject: [PATCH 08/10] minor ci fixes --- tests/conftest.py | 2 +- tests/test_arrow_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f4e79eb5bf3..e9bb542c954 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,7 +52,7 @@ def set_sqlalchemy_silence_uber_warning(monkeypatch): # To be removed once SQLAlchemy 2.0 supported try: monkeypatch.setattr("sqlalchemy.util.deprecations.SILENCE_UBER_WARNING", True) - except AttributeError: + except (ModuleNotFoundError, AttributeError): 
pass diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 3a8e19b586a..a50ac62757c 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -4165,7 +4165,7 @@ def test_dummy_dataset_serialize_fs(dataset, mockfs): [ "relative/path", "/absolute/path", - "s3://bucket/relative/path", + "hf://bucket/relative/path", "hdfs://relative/path", "hdfs:///absolute/path", ], @@ -4179,7 +4179,7 @@ def test_build_local_temp_path(uri_or_path): assert ( "hdfs://" not in path_relative_to_tmp_dir - and "s3://" not in path_relative_to_tmp_dir + and "hf://" not in path_relative_to_tmp_dir and not local_temp_path.startswith(extracted_path_without_anchor) and local_temp_path.endswith(extracted_path_without_anchor) ), f"Local temp path: {local_temp_path}" From 861fc710c41bf43dc5beb06034a8091e02496098 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 12 Feb 2025 16:24:19 +0100 Subject: [PATCH 09/10] add tests --- src/datasets/arrow_dataset.py | 5 +++- src/datasets/iterable_dataset.py | 5 +++- tests/test_arrow_dataset.py | 46 ++++++++++++++++++++++++++++++++ tests/test_iterable_dataset.py | 46 ++++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 2 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index ee97c37b4cc..864318e7234 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3431,7 +3431,10 @@ def iter_outputs(shard_iterable): if inspect.iscoroutinefunction(function): indices: Union[List[int], List[List[int]]] = [] tasks: List[asyncio.Task] = [] - loop = asyncio.get_event_loop() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() for i, example in shard_iterable: indices.append(i) tasks.append(loop.create_task(async_apply_function(example, i, offset=offset))) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index a95b276d89b..5faf95ff8ea 100644 --- 
a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1176,7 +1176,10 @@ def iter_outputs(): if inspect.iscoroutinefunction(self.function): indices: Union[List[int], List[List[int]]] = [] tasks: List[asyncio.Task] = [] - loop = asyncio.get_event_loop() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() for i, key_example in inputs_iterator: indices.append(i) tasks.append(loop.create_task(async_apply_function(key_example, i))) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index a50ac62757c..2b490607179 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import copy import itertools @@ -7,6 +8,7 @@ import re import sys import tempfile +import time from functools import partial from pathlib import Path from unittest import TestCase @@ -4422,6 +4424,50 @@ def f(x): assert outputs == {"a": [{"nested": [[i]]} for i in [-1, -1, 2, 3]]} +def test_map_async(): + dset = Dataset.from_dict({"x": range(100)}) + + async def f(example): + await asyncio.sleep(0.1) + return {"y": 1} + + _start = time.time() + out = dset.map(f) + assert time.time() - _start < 2.0 + assert out[0]["y"] == 1 + + async def f(batch): + await asyncio.sleep(0.1) + return {"y": [1] * len(batch["x"])} + + _start = time.time() + out = dset.map(f, batched=True) + assert time.time() - _start < 2.0 + assert out[0]["y"] == 1 + + +def test_filter_async(): + dset = Dataset.from_dict({"x": range(100)}) + + async def f(example): + await asyncio.sleep(0.1) + return example["x"] == 42 + + _start = time.time() + out = dset.filter(f) + assert time.time() - _start < 2.0 + assert len(out) == 1 + + async def f(batch): + await asyncio.sleep(0.1) + return [x == 42 for x in batch["x"]] + + _start = time.time() + out = dset.filter(f, batched=True) + assert time.time() - _start < 2.0 + assert len(out) == 1 + + def test_dataset_getitem_raises(): ds = 
Dataset.from_dict({"a": [0, 1, 2, 3]}) with pytest.raises(TypeError): diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index bd79863f9c3..2f0fc051654 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1,4 +1,6 @@ +import asyncio import pickle +import time from copy import deepcopy from itertools import chain, cycle, islice from unittest.mock import patch @@ -1143,6 +1145,50 @@ def test_filtered_examples_iterable_input_columns(n, func, batched, batch_size, assert_load_state_dict_resumes_iteration(ex_iterable) +def test_map_async(): + dset = Dataset.from_dict({"x": range(100)}).to_iterable_dataset() + + async def f(example): + await asyncio.sleep(0.1) + return {"y": 1} + + _start = time.time() + out = dset.map(f) + assert time.time() - _start < 2.0 + assert next(iter(out))["y"] == 1 + + async def f(batch): + await asyncio.sleep(0.1) + return {"y": [1] * len(batch["x"])} + + _start = time.time() + out = dset.map(f, batched=True) + assert time.time() - _start < 2.0 + assert next(iter(out))["y"] == 1 + + +def test_filter_async(): + dset = Dataset.from_dict({"x": range(100)}).to_iterable_dataset() + + async def f(example): + await asyncio.sleep(0.1) + return example["x"] == 42 + + _start = time.time() + out = dset.filter(f) + assert time.time() - _start < 2.0 + assert len(list(out)) == 1 + + async def f(batch): + await asyncio.sleep(0.1) + return [x == 42 for x in batch["x"]] + + _start = time.time() + out = dset.filter(f, batched=True) + assert time.time() - _start < 2.0 + assert len(list(out)) == 1 + + def test_skip_examples_iterable(): total, count = 10, 2 base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": total}) From d21dec2382b6560138672b172b9bf6de7192bf0f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 13 Feb 2025 13:00:39 +0100 Subject: [PATCH 10/10] docs --- docs/source/process.mdx | 46 ++++++++++++++++++++++++++++++++ src/datasets/arrow_dataset.py | 8 ++++++ 
src/datasets/iterable_dataset.py | 8 ++++++ 3 files changed, 62 insertions(+) diff --git a/docs/source/process.mdx b/docs/source/process.mdx index 712dac4de4c..0aa091a63dc 100644 --- a/docs/source/process.mdx +++ b/docs/source/process.mdx @@ -502,6 +502,52 @@ Use [`~Dataset.map`] to apply the function over the whole dataset: For each original sentence, RoBERTA augmented a random word with three alternatives. The original word `distorting` is supplemented by `withholding`, `suppressing`, and `destroying`. +### Run asynchronous calls + +Asynchronous functions are useful to call API endpoints in parallel, for example to download content like images or call a model endpoint. + +You can define an asynchronous function using the `async` and `await` keywords, here is an example function to call a chat model from Hugging Face: + +```python +>>> import aiohttp +>>> import asyncio +>>> from huggingface_hub import get_token +>>> sem = asyncio.Semaphore(20) # max number of simultaneous queries +>>> async def query_model(model, prompt): +... api_url = f"https://api-inference.huggingface.co/models/{model}/v1/chat/completions" +... headers = {"Authorization": f"Bearer {get_token()}", "Content-Type": "application/json"} +... json = {"messages": [{"role": "user", "content": prompt}], "max_tokens": 20, "seed": 42} +... async with sem, aiohttp.ClientSession() as session, session.post(api_url, headers=headers, json=json) as response: +... output = await response.json() +... return {"Output": output["choices"][0]["message"]["content"]} +``` + +Asynchronous functions run in parallel, which accelerates the process a lot. The same code takes a lot more time if it's run sequentially, because it does nothing while waiting for the model response. It is generally recommended to use `async` / `await` when your function has to wait for a response from an API for example, or if it downloads data and it can take some time. 
+ +Note the presence of a `Semaphore`: it sets the maximum number of queries that can run in parallel. It is recommended to use a `Semaphore` when calling APIs to avoid rate limit errors. + +Let's use it to call the [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) model and ask it to return the main topic of each math problem in the [Maxwell-Jia/AIME_2024](https://huggingface.co/Maxwell-Jia/AIME_2024) dataset: + +```python +>>> from datasets import load_dataset +>>> ds = load_dataset("Maxwell-Jia/AIME_2024", split="train") +>>> model = "microsoft/Phi-3-mini-4k-instruct" +>>> prompt = 'What is this text mainly about ? Here is the text:\n\n```\n{Problem}\n```\n\nReply using one or two words max, e.g. "The main topic is Linear Algebra".' +>>> async def get_topic(example): +... return await query_model(model, prompt.format(Problem=example['Problem'])) +>>> ds = ds.map(get_topic) +>>> ds[0] +{'ID': '2024-II-4', + 'Problem': 'Let $x,y$ and $z$ be positive real numbers that...', + 'Solution': 'Denote $\\log_2(x) = a$, $\\log_2(y) = b$, and...', + 'Answer': 33, + 'Output': 'The main topic is Logarithms.'} +``` + +Here, [`Dataset.map`] runs many `get_topic` functions asynchronously so it doesn't have to wait for every single model response which would take a lot of time to do sequentially. + +By default, [`Dataset.map`] runs up to one thousand queries in parallel, so don't forget to set the maximum number of queries that can run in parallel with a `Semaphore`, otherwise the model could return rate limit errors or become overloaded. For advanced use cases, you can change the maximum number of queries in parallel in `datasets.config`. + ### Process multiple splits Many datasets have splits that can be processed simultaneously with [`DatasetDict.map`]. 
For example, tokenize the `sentence1` field in the train and test split by: diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 864318e7234..61ae4e4a988 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -2871,6 +2871,9 @@ def map( Note that the last batch may have less than `n` examples. A batch is a dictionary, e.g. a batch of `n` examples is `{"text": ["Hello there !"] * n}`. + If the function is asynchronous, then `map` will run your function in parallel, with up to one thousand simultaneous calls. + It is recommended to use an `asyncio.Semaphore` in your function if you want to set a maximum number of operations that can run at the same time. + Args: function (`Callable`): Function with one of the following signatures: @@ -2880,6 +2883,7 @@ def map( - `function(batch: Dict[str, List], *extra_args) -> Dict[str, List]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) For advanced usage, the function can also return a `pyarrow.Table`. + If the function is asynchronous, then `map` will run your function in parallel. Moreover if your function returns nothing (`None`), then `map` will run your function and return the dataset unchanged. If no function is provided, default to identity function: `lambda x: x`. with_indices (`bool`, defaults to `False`): @@ -3633,6 +3637,9 @@ def filter( """Apply a filter function to all the elements in the table in batches and update the table so that the dataset only includes examples according to the filter function. + If the function is asynchronous, then `filter` will run your function in parallel, with up to one thousand simultaneous calls (configurable). + It is recommended to use an `asyncio.Semaphore` in your function if you want to set a maximum number of operations that can run at the same time. 
+ Args: function (`Callable`): Callable with one of the following signatures: @@ -3641,6 +3648,7 @@ def filter( - `function(batch: Dict[str, List]) -> List[bool]` if `batched=True` and `with_indices=False` and `with_rank=False` - `function(batch: Dict[str, List], *extra_args) -> List[bool]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + If the function is asynchronous, then `filter` will run your function in parallel. If no function is provided, defaults to an always `True` function: `lambda x: True`. with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 5faf95ff8ea..da5132943a7 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2406,6 +2406,9 @@ def map( Note that the last batch may have less than `n` examples. A batch is a dictionary, e.g. a batch of `n` examples is `{"text": ["Hello there !"] * n}`. + If the function is asynchronous, then `map` will run your function in parallel, with up to one thousand simultaneous calls. + It is recommended to use an `asyncio.Semaphore` in your function if you want to set a maximum number of operations that can run at the same time. + Args: function (`Callable`, *optional*, defaults to `None`): Function applied on-the-fly on the examples when you iterate on the dataset. @@ -2417,6 +2420,7 @@ def map( - `function(batch: Dict[str, List], indices: List[int]) -> Dict[str, List]` if `batched=True` and `with_indices=True` For advanced usage, the function can also return a `pyarrow.Table`. + If the function is asynchronous, then `map` will run your function in parallel. Moreover if your function returns nothing (`None`), then `map` will run your function and return the dataset unchanged. If no function is provided, default to identity function: `lambda x: x`. 
with_indices (`bool`, defaults to `False`): @@ -2537,6 +2541,9 @@ def filter( """Apply a filter function to all the elements so that the dataset only includes examples according to the filter function. The filtering is done on-the-fly when iterating over the dataset. + If the function is asynchronous, then `filter` will run your function in parallel, with up to one thousand simultaneous calls (configurable). + It is recommended to use an `asyncio.Semaphore` in your function if you want to set a maximum number of operations that can run at the same time. + Args: function (`Callable`): Callable with one of the following signatures: @@ -2546,6 +2553,7 @@ def filter( - `function(example: Dict[str, List]) -> List[bool]` if `with_indices=False, batched=True` - `function(example: Dict[str, List], indices: List[int]) -> List[bool]` if `with_indices=True, batched=True` + If the function is asynchronous, then `filter` will run your function in parallel. If no function is provided, defaults to an always True function: `lambda x: True`. with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`.