Skip to content

Commit 00fe4d1

Browse files
sjrlAmnah199
andauthored
feat: Add run async for AzureOpenAIChatGenerator (#8948)
* Add tests for run_async * Add reno * Add async client * Add init test * Add comment * Fix test * Update releasenotes/notes/run-async-azure-54450f0c2495f5c8.yaml Co-authored-by: Amna Mubashar <[email protected]> --------- Co-authored-by: Amna Mubashar <[email protected]>
1 parent 52a0282 commit 00fe4d1

File tree

3 files changed

+88
-12
lines changed

3 files changed

+88
-12
lines changed

haystack/components/generators/chat/azure.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Callable, Dict, List, Optional
77

88
# pylint: disable=import-error
9-
from openai.lib.azure import AzureOpenAI
9+
from openai.lib.azure import AsyncAzureOpenAI, AzureOpenAI
1010

1111
from haystack import component, default_from_dict, default_to_dict, logging
1212
from haystack.components.generators.chat import OpenAIChatGenerator
@@ -154,17 +154,20 @@ def __init__( # pylint: disable=too-many-positional-arguments
154154
self.tools = tools
155155
self.tools_strict = tools_strict
156156

157-
self.client = AzureOpenAI(
158-
api_version=api_version,
159-
azure_endpoint=azure_endpoint,
160-
azure_deployment=azure_deployment,
161-
api_key=api_key.resolve_value() if api_key is not None else None,
162-
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
163-
organization=organization,
164-
timeout=self.timeout,
165-
max_retries=self.max_retries,
166-
default_headers=self.default_headers,
167-
)
157+
client_args: Dict[str, Any] = {
158+
"api_version": api_version,
159+
"azure_endpoint": azure_endpoint,
160+
"azure_deployment": azure_deployment,
161+
"api_key": api_key.resolve_value() if api_key is not None else None,
162+
"azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None,
163+
"organization": organization,
164+
"timeout": self.timeout,
165+
"max_retries": self.max_retries,
166+
"default_headers": self.default_headers,
167+
}
168+
169+
self.client = AzureOpenAI(**client_args)
170+
self.async_client = AsyncAzureOpenAI(**client_args)
168171

169172
def to_dict(self) -> Dict[str, Any]:
170173
"""
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
---
3+
features:
4+
- |
5+
Add `run_async` method to `AzureOpenAIChatGenerator`. This method uses `AsyncAzureOpenAI`
6+
to generate chat completions and supports the same parameters as the `run` method. It returns a coroutine
7+
that can be awaited.

test/components/generators/chat/test_azure.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,69 @@ def test_live_run_with_tools(self, tools):
242242
assert message.meta["finish_reason"] == "tool_calls"
243243

244244
# additional tests intentionally omitted as they are covered by test_openai.py
245+
246+
247+
class TestAzureOpenAIChatGeneratorAsync:
248+
def test_init_should_also_create_async_client_with_same_args(self, tools):
249+
component = AzureOpenAIChatGenerator(
250+
api_key=Secret.from_token("test-api-key"),
251+
azure_endpoint="some-non-existing-endpoint",
252+
streaming_callback=print_streaming_chunk,
253+
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
254+
tools=tools,
255+
tools_strict=True,
256+
)
257+
assert component.async_client.api_key == "test-api-key"
258+
assert component.azure_deployment == "gpt-4o-mini"
259+
assert component.streaming_callback is print_streaming_chunk
260+
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
261+
assert component.tools == tools
262+
assert component.tools_strict
263+
264+
@pytest.mark.integration
265+
@pytest.mark.skipif(
266+
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
267+
reason=(
268+
"Please export env variables called AZURE_OPENAI_API_KEY containing "
269+
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
270+
"the Azure OpenAI endpoint URL to run this test."
271+
),
272+
)
273+
@pytest.mark.asyncio
274+
async def test_live_run_async(self):
275+
chat_messages = [ChatMessage.from_user("What's the capital of France")]
276+
component = AzureOpenAIChatGenerator(generation_kwargs={"n": 1})
277+
results = await component.run_async(chat_messages)
278+
assert len(results["replies"]) == 1
279+
message: ChatMessage = results["replies"][0]
280+
assert "Paris" in message.text
281+
assert "gpt-4o" in message.meta["model"]
282+
assert message.meta["finish_reason"] == "stop"
283+
284+
@pytest.mark.integration
285+
@pytest.mark.skipif(
286+
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
287+
reason=(
288+
"Please export env variables called AZURE_OPENAI_API_KEY containing "
289+
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
290+
"the Azure OpenAI endpoint URL to run this test."
291+
),
292+
)
293+
@pytest.mark.asyncio
294+
async def test_live_run_with_tools_async(self, tools):
295+
chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
296+
component = AzureOpenAIChatGenerator(tools=tools)
297+
results = await component.run_async(chat_messages)
298+
assert len(results["replies"]) == 1
299+
message = results["replies"][0]
300+
301+
assert not message.texts
302+
assert not message.text
303+
assert message.tool_calls
304+
tool_call = message.tool_call
305+
assert isinstance(tool_call, ToolCall)
306+
assert tool_call.tool_name == "weather"
307+
assert tool_call.arguments == {"city": "Paris"}
308+
assert message.meta["finish_reason"] == "tool_calls"
309+
310+
# additional tests intentionally omitted as they are covered by test_openai.py

0 commit comments

Comments
 (0)