
Commit 605f1f0

[WIP] Azure Client integration
1 parent 98177d9 commit 605f1f0

File tree: 2 files changed, +460 −0 lines changed

Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
"""Azure OpenAI ModelClient integration."""

import logging
import os
from typing import Dict, Optional, Any, Callable, Literal

import backoff

from adalflow.core.model_client import ModelClient
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput

# optional import
from adalflow.utils.lazy_import import safe_import, OptionalPackages

openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])

from openai import AzureOpenAI, AsyncAzureOpenAI, Stream
from openai import (
    APITimeoutError,
    InternalServerError,
    RateLimitError,
    UnprocessableEntityError,
    BadRequestError,
)
from openai.types import (
    Completion,
    CreateEmbeddingResponse,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletion

# CreateEmbeddingResponse and parse_embedding_response are imported for the
# embedding path, which is not wired into a parser yet in this WIP commit.
from adalflow.components.model_client.utils import parse_embedding_response

log = logging.getLogger(__name__)

def get_first_message_content(completion: ChatCompletion) -> str:
    """Return the content of the first message in the completion.

    This is the default parser for chat completions."""
    return completion.choices[0].message.content


def parse_stream_response(completion: ChatCompletionChunk) -> str:
    """Parse one chunk of a streaming API response."""
    return completion.choices[0].delta.content


def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
    """Yield the parsed content of each chunk in a streaming response."""
    for completion in generator:
        log.debug(f"Raw chunk completion: {completion}")
        parsed_content = parse_stream_response(completion)
        yield parsed_content

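# A note on the helpers above: get_first_message_content is the default
# non-streaming parser, while handle_streaming_response adapts a
# Stream[ChatCompletionChunk] into plain text deltas. A chunk's delta.content
# can be None (for example, the initial role-only chunk), so consumers joining
# the pieces may want to filter out None values, e.g.:
#     text = "".join(c for c in handle_streaming_response(stream) if c)
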
class AzureClient(ModelClient):
    """A component wrapper for the Azure OpenAI API client.

    This client supports both the chat completion and embedding APIs through
    Azure OpenAI, and can be used for both sync and async operations.

    Args:
        api_key (Optional[str]): Azure OpenAI API key.
        api_version (Optional[str]): API version to use.
        azure_endpoint (Optional[str]): Azure OpenAI endpoint URL (e.g., https://<resource-name>.openai.azure.com/).
        base_url (Optional[str]): Alternative base URL format (e.g., https://<model-deployment-name>.<region>.models.ai.azure.com).
        chat_completion_parser (Optional[Callable]): Function to parse chat completions.
        input_type (Literal["text", "messages"]): Format for input.

    Environment Variables:
        AZURE_OPENAI_API_KEY: API key
        AZURE_OPENAI_ENDPOINT: Endpoint URL (new format)
        AZURE_BASE_URL: Base URL (alternative format)
        AZURE_OPENAI_VERSION: API version

    Example:
        >>> from adalflow.core import Generator
        >>> from adalflow.components.model_client import AzureClient
        >>> client = AzureClient()
        >>> generator = Generator(
        ...     model_client=client,
        ...     model_kwargs={
        ...         "model": "gpt-4",
        ...         "temperature": 0.7
        ...     }
        ... )
        >>> response = generator({"input_str": "What is the capital of France?"})
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        base_url: Optional[str] = None,
        chat_completion_parser: Optional[Callable[[Completion], Any]] = None,
        input_type: Literal["text", "messages"] = "text",
    ):
        super().__init__()
        self._api_key = api_key
        self._api_version = api_version
        self._azure_endpoint = azure_endpoint
        self._base_url = base_url
        self.sync_client = self.init_sync_client()
        self.async_client = None  # initialized lazily on the first acall
        self.chat_completion_parser = chat_completion_parser or get_first_message_content
        self._input_type = input_type

    def _get_endpoint(self) -> str:
        """Get the appropriate endpoint URL based on the available configuration."""
        # First try the new format endpoint
        endpoint = self._azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if endpoint:
            return endpoint

        # Then try the alternative base URL format
        base_url = self._base_url or os.getenv("AZURE_BASE_URL")
        if base_url:
            # If base_url is already in the form
            # https://<model>.<region>.models.ai.azure.com, use it as-is
            if "models.ai.azure.com" in base_url:
                return base_url.rstrip("/")
            # If it's just the resource name, construct the full URL
            return f"https://{base_url}.openai.azure.com"

        raise ValueError(
            "Either AZURE_OPENAI_ENDPOINT or AZURE_BASE_URL must be set. "
            "Check your deployment page for a URL like: "
            "https://<resource-name>.openai.azure.com/ or "
            "https://<model-deployment-name>.<region>.models.ai.azure.com"
        )

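    # Endpoint resolution sketch (hypothetical resource name "my-res"):
    # with AZURE_OPENAI_ENDPOINT=https://my-res.openai.azure.com/ that value is
    # returned unchanged; with only AZURE_BASE_URL=my-res the URL is built as
    # https://my-res.openai.azure.com; an AZURE_BASE_URL already in the
    # models.ai.azure.com form is used as-is, minus any trailing slash.
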
    def init_sync_client(self):
        api_key = self._api_key or os.getenv("AZURE_OPENAI_API_KEY")
        api_version = self._api_version or os.getenv("AZURE_OPENAI_VERSION")

        if not api_key:
            raise ValueError("Environment variable AZURE_OPENAI_API_KEY must be set")
        if not api_version:
            raise ValueError("Environment variable AZURE_OPENAI_VERSION must be set")

        endpoint = self._get_endpoint()

        return AzureOpenAI(
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=endpoint,
        )

    def init_async_client(self):
        api_key = self._api_key or os.getenv("AZURE_OPENAI_API_KEY")
        api_version = self._api_version or os.getenv("AZURE_OPENAI_VERSION")

        if not api_key:
            raise ValueError("Environment variable AZURE_OPENAI_API_KEY must be set")
        if not api_version:
            raise ValueError("Environment variable AZURE_OPENAI_VERSION must be set")

        endpoint = self._get_endpoint()

        return AsyncAzureOpenAI(
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=endpoint,
        )

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Dict = {},
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        """Convert inputs to the Azure OpenAI API kwargs format."""
        final_model_kwargs = model_kwargs.copy()

        if model_type == ModelType.EMBEDDER:
            if isinstance(input, str):
                input = [input]
            assert isinstance(input, (list, tuple)), "input must be a sequence of text"
            final_model_kwargs["input"] = input
        elif model_type == ModelType.LLM:
            messages = []
            if input is not None and input != "":
                if self._input_type == "text":
                    messages.append({"role": "system", "content": input})
                else:
                    messages.extend(input)
            final_model_kwargs["messages"] = messages
        else:
            raise ValueError(f"model_type {model_type} is not supported")

        # Ensure the model (the deployment name, for Azure) is specified
        if "model" not in final_model_kwargs:
            raise ValueError("model must be specified")

        return final_model_kwargs

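    # Example mapping (hypothetical deployment name "gpt-4o"): with the default
    # input_type="text",
    #     convert_inputs_to_api_kwargs("Hi", {"model": "gpt-4o"}, ModelType.LLM)
    # returns
    #     {"model": "gpt-4o", "messages": [{"role": "system", "content": "Hi"}]}
    # while for ModelType.EMBEDDER a bare string is wrapped into a one-element
    # list before being passed as "input".
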
    def parse_chat_completion(self, completion: ChatCompletion) -> GeneratorOutput:
        """Parse the chat completion response into a GeneratorOutput."""
        log.debug(f"completion: {completion}")
        try:
            data = self.chat_completion_parser(completion)
            usage = self.track_completion_usage(completion)
            return GeneratorOutput(data=None, usage=usage, raw_response=data)
        except Exception as e:
            log.error(f"Error parsing completion: {e}")
            return GeneratorOutput(
                data=None, error=str(e), raw_response=str(completion)
            )

    def track_completion_usage(self, completion: ChatCompletion) -> CompletionUsage:
        """Track completion token usage."""
        usage = completion.usage
        return CompletionUsage(
            completion_tokens=usage.completion_tokens,
            prompt_tokens=usage.prompt_tokens,
            total_tokens=usage.total_tokens,
        )

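    # Note: following the convention of adalflow's other model clients, the
    # parsed text is carried in raw_response while data stays None; downstream
    # output processors are expected to fill in data. usage holds the token
    # counts collected by track_completion_usage above.
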
    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
            BadRequestError,
        ),
        max_time=5,
    )
    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
        """Make a synchronous call to the Azure OpenAI API."""
        log.info(f"api_kwargs: {api_kwargs}")
        if model_type == ModelType.EMBEDDER:
            return self.sync_client.embeddings.create(**api_kwargs)
        elif model_type == ModelType.LLM:
            if api_kwargs.get("stream", False):
                log.debug("streaming call")
                self.chat_completion_parser = handle_streaming_response
            return self.sync_client.chat.completions.create(**api_kwargs)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

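    # Streaming sketch (hypothetical deployment name "gpt-4o"): with
    # stream=True the parser is swapped to handle_streaming_response and the
    # raw Stream object is returned, so a caller can iterate text deltas:
    #     stream = client.call(
    #         {"model": "gpt-4o", "messages": [...], "stream": True},
    #         ModelType.LLM,
    #     )
    #     for text in handle_streaming_response(stream):
    #         print(text or "", end="")
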
    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
            BadRequestError,
        ),
        max_time=5,
    )
    async def acall(
        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
    ):
        """Make an asynchronous call to the Azure OpenAI API."""
        if self.async_client is None:
            self.async_client = self.init_async_client()
        if model_type == ModelType.EMBEDDER:
            return await self.async_client.embeddings.create(**api_kwargs)
        elif model_type == ModelType.LLM:
            return await self.async_client.chat.completions.create(**api_kwargs)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureClient":
        """Create an instance from a dictionary."""
        obj = super().from_dict(data)
        obj.sync_client = obj.init_sync_client()
        obj.async_client = obj.init_async_client()
        return obj

    def to_dict(self) -> Dict[str, Any]:
        """Convert the instance to a dictionary, excluding the client objects."""
        exclude = ["sync_client", "async_client"]
        output = super().to_dict(exclude=exclude)
        return output
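

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch that assumes AZURE_OPENAI_API_KEY,
    # AZURE_OPENAI_VERSION, and AZURE_OPENAI_ENDPOINT are set, and that a chat
    # deployment named "gpt-4o" (hypothetical) exists on the resource.
    client = AzureClient()
    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="What is the capital of France?",
        model_kwargs={"model": "gpt-4o"},
        model_type=ModelType.LLM,
    )
    completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
    print(client.parse_chat_completion(completion))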
