Skip to content

Commit 339e9dc

Browse files
authored
Support async functions in map() (#7384)
* support async functions in map() * simplify code * batched async * async map in iterable dataset * async filter * fix ci * fix ci * minor ci fixes * add tests * docs
1 parent 3a4e74a commit 339e9dc

File tree

7 files changed

+462
-301
lines changed

7 files changed

+462
-301
lines changed

docs/source/process.mdx

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,52 @@ Use [`~Dataset.map`] to apply the function over the whole dataset:
502502

503503
For each original sentence, RoBERTA augmented a random word with three alternatives. The original word `distorting` is supplemented by `withholding`, `suppressing`, and `destroying`.
504504

505+
### Run asynchronous calls
506+
507+
Asynchronous functions are useful to call API endpoints in parallel, for example to download content like images or call a model endpoint.
508+
509+
You can define an asynchronous function using the `async` and `await` keywords, here is an example function to call a chat model from Hugging Face:
510+
511+
```python
512+
>>> import aiohttp
513+
>>> import asyncio
514+
>>> from huggingface_hub import get_token
515+
>>> sem = asyncio.Semaphore(20) # max number of simultaneous queries
516+
>>> async def query_model(model, prompt):
517+
... api_url = f"https://api-inference.huggingface.co/models/{model}/v1/chat/completions"
518+
... headers = {"Authorization": f"Bearer {get_token()}", "Content-Type": "application/json"}
519+
... json = {"messages": [{"role": "user", "content": prompt}], "max_tokens": 20, "seed": 42}
520+
... async with sem, aiohttp.ClientSession() as session, session.post(api_url, headers=headers, json=json) as response:
521+
... output = await response.json()
522+
... return {"Output": output["choices"][0]["message"]["content"]}
523+
```
524+
525+
Asynchronous functions run in parallel, which accelerates the process a lot. The same code takes a lot more time if it's run sequentially, because it does nothing while waiting for the model response. It is generally recommended to use `async` / `await` when you function has to wait for a response from an API for example, or if it downloads data and it can take some time.
526+
527+
Note the presence of a `Semaphore`: it sets the maximum number of queries that can run in parallel. It is recommended to use a `Semaphore` when calling APIs to avoid rate limit errors.
528+
529+
Let's use it to call the [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) model and ask it to return the main topic of each math problem in the [Maxwell-Jia/AIME_2024](https://huggingface.co/Maxwell-Jia/AIME_2024) dataset:
530+
531+
```python
532+
>>> from datasets import load_dataset
533+
>>> ds = load_dataset("Maxwell-Jia/AIME_2024", split="train")
534+
>>> model = "microsoft/Phi-3-mini-4k-instruct"
535+
>>> prompt = 'What is this text mainly about ? Here is the text:\n\n```\n{Problem}\n```\n\nReply using one or two words max, e.g. "The main topic is Linear Algebra".'
536+
>>> async def get_topic(example):
537+
... return await query_model(model, prompt.format(Problem=example['Problem']))
538+
>>> ds = ds.map(get_topic)
539+
>>> ds[0]
540+
{'ID': '2024-II-4',
541+
'Problem': 'Let $x,y$ and $z$ be positive real numbers that...',
542+
'Solution': 'Denote $\\log_2(x) = a$, $\\log_2(y) = b$, and...,
543+
'Answer': 33,
544+
'Output': 'The main topic is Logarithms.'}
545+
```
546+
547+
Here, [`Dataset.map`] runs many `get_topic` function asynchronously so it doesn't have to wait for every single model response which would take a lot of time to do sequentially.
548+
549+
By default, [`Dataset.map`] runs up to one thousand queries in parallel, so don't forget to set the maximum number of queries that can run in parallel with a `Semaphore`, otherwise the model could return rate limit errors or overload. For advanced use cases, you can change the maximum number of queries in parallel in `datasets.config`.
550+
505551
### Process multiple splits
506552

507553
Many datasets have splits that can be processed simultaneously with [`DatasetDict.map`]. For example, tokenize the `sentence1` field in the train and test split by:

0 commit comments

Comments
 (0)