1+ """Azure OpenAI ModelClient integration."""
2+
3+ import os
4+ from typing import Dict , Optional , Any , Callable , Literal
5+ import backoff
6+ import logging
7+
8+ from adalflow .core .model_client import ModelClient
9+ from adalflow .core .types import ModelType , CompletionUsage , GeneratorOutput
10+
11+ # optional import
12+ from adalflow .utils .lazy_import import safe_import , OptionalPackages
13+
14+ openai = safe_import (OptionalPackages .OPENAI .value [0 ], OptionalPackages .OPENAI .value [1 ])
15+
16+ from openai import AzureOpenAI , AsyncAzureOpenAI , Stream
17+ from openai import (
18+ APITimeoutError ,
19+ InternalServerError ,
20+ RateLimitError ,
21+ UnprocessableEntityError ,
22+ BadRequestError ,
23+ )
24+ from openai .types import (
25+ Completion ,
26+ CreateEmbeddingResponse ,
27+ )
28+ from openai .types .chat import ChatCompletionChunk , ChatCompletion
29+ from adalflow .components .model_client .utils import parse_embedding_response
30+
31+ log = logging .getLogger (__name__ )
32+
33+ def get_first_message_content (completion : ChatCompletion ) -> str :
34+ """When we only need the content of the first message.
35+ It is the default parser for chat completion."""
36+ return completion .choices [0 ].message .content

def parse_stream_response(completion: ChatCompletionChunk) -> Optional[str]:
    """Parse the response of the stream API."""
    # Azure may emit chunks with an empty ``choices`` list (for example, a
    # content-filter preamble), so guard before indexing.
    if not completion.choices:
        return None
    return completion.choices[0].delta.content

def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
    """Handle the streaming response."""
    for completion in generator:
        log.debug(f"Raw chunk completion: {completion}")
        parsed_content = parse_stream_response(completion)
        if parsed_content is not None:
            yield parsed_content

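# A minimal usage sketch for the streaming path (illustrative, not part of the
# client API; assumes valid Azure credentials and a chat model deployed under
# the hypothetical name "gpt-4o"):
#
#   client = AzureClient()
#   stream = client.call(
#       api_kwargs={
#           "model": "gpt-4o",
#           "messages": [{"role": "user", "content": "Say hi"}],
#           "stream": True,
#       },
#       model_type=ModelType.LLM,
#   )
#   for text in handle_streaming_response(stream):
#       print(text, end="", flush=True)
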
class AzureClient(ModelClient):
    """A component wrapper for the Azure OpenAI API client.

    This client supports both the chat completion and embedding APIs through
    Azure OpenAI, with both sync and async operations.

    Args:
        api_key (Optional[str]): Azure OpenAI API key.
        api_version (Optional[str]): API version to use.
        azure_endpoint (Optional[str]): Azure OpenAI endpoint URL (e.g., https://<resource-name>.openai.azure.com/).
        base_url (Optional[str]): Alternative base URL format (e.g., https://<model-deployment-name>.<region>.models.ai.azure.com).
        chat_completion_parser (Optional[Callable]): Function to parse chat completions.
        input_type (Literal["text", "messages"]): Format for input.

    Environment Variables:
        AZURE_OPENAI_API_KEY: API key
        AZURE_OPENAI_ENDPOINT: Endpoint URL (new format)
        AZURE_BASE_URL: Base URL (alternative format)
        AZURE_OPENAI_VERSION: API version

    Example:
        >>> from adalflow.core import Generator
        >>> from adalflow.components.model_client import AzureClient
        >>> client = AzureClient()
        >>> generator = Generator(
        ...     model_client=client,
        ...     model_kwargs={
        ...         "model": "gpt-4",
        ...         "temperature": 0.7
        ...     }
        ... )
        >>> response = generator({"input_str": "What is the capital of France?"})
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        base_url: Optional[str] = None,
        chat_completion_parser: Optional[Callable[[Completion], Any]] = None,
        input_type: Literal["text", "messages"] = "text",
    ):
        super().__init__()
        self._api_key = api_key
        self._api_version = api_version
        self._azure_endpoint = azure_endpoint
        self._base_url = base_url
        self.sync_client = self.init_sync_client()
        # The async client is created lazily on the first `acall`.
        self.async_client = None
        self.chat_completion_parser = chat_completion_parser or get_first_message_content
        self._input_type = input_type

    def _get_endpoint(self) -> str:
        """Get the appropriate endpoint URL based on the available configuration."""
        # First try the new format endpoint.
        endpoint = self._azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if endpoint:
            return endpoint

        # Then try the alternative base URL format.
        base_url = self._base_url or os.getenv("AZURE_BASE_URL")
        if base_url:
            # If base_url is provided in the format
            # https://<model>.<region>.models.ai.azure.com, use it as-is.
            if "models.ai.azure.com" in base_url:
                return base_url.rstrip("/")
            # If it is just the resource name, construct the full URL.
            return f"https://{base_url}.openai.azure.com"

        raise ValueError(
            "Either AZURE_OPENAI_ENDPOINT or AZURE_BASE_URL must be set. "
            "Check your deployment page for a URL like: "
            "https://<resource-name>.openai.azure.com/ or "
            "https://<model-deployment-name>.<region>.models.ai.azure.com"
        )

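    # Illustrative resolution under the two configurations (hypothetical
    # values, assuming no constructor arguments were passed):
    #   AZURE_OPENAI_ENDPOINT=https://my-resource.openai.azure.com/
    #       -> "https://my-resource.openai.azure.com/"
    #   AZURE_BASE_URL=https://my-deployment.eastus2.models.ai.azure.com/
    #       -> "https://my-deployment.eastus2.models.ai.azure.com"
    #   AZURE_BASE_URL=my-resource
    #       -> "https://my-resource.openai.azure.com"
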
    def init_sync_client(self):
        api_key = self._api_key or os.getenv("AZURE_OPENAI_API_KEY")
        api_version = self._api_version or os.getenv("AZURE_OPENAI_VERSION")

        if not api_key:
            raise ValueError(
                "api_key must be passed or the environment variable AZURE_OPENAI_API_KEY must be set"
            )
        if not api_version:
            raise ValueError(
                "api_version must be passed or the environment variable AZURE_OPENAI_VERSION must be set"
            )

        endpoint = self._get_endpoint()

        return AzureOpenAI(
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=endpoint,
        )

    def init_async_client(self):
        api_key = self._api_key or os.getenv("AZURE_OPENAI_API_KEY")
        api_version = self._api_version or os.getenv("AZURE_OPENAI_VERSION")

        if not api_key:
            raise ValueError(
                "api_key must be passed or the environment variable AZURE_OPENAI_API_KEY must be set"
            )
        if not api_version:
            raise ValueError(
                "api_version must be passed or the environment variable AZURE_OPENAI_VERSION must be set"
            )

        endpoint = self._get_endpoint()

        return AsyncAzureOpenAI(
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=endpoint,
        )

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Dict = {},
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        """Convert inputs to the Azure OpenAI API kwargs format."""
        final_model_kwargs = model_kwargs.copy()

        if model_type == ModelType.EMBEDDER:
            if isinstance(input, str):
                input = [input]
            assert isinstance(input, (list, tuple)), "input must be a sequence of text"
            final_model_kwargs["input"] = input
        elif model_type == ModelType.LLM:
            messages = []
            if input is not None and input != "":
                if self._input_type == "text":
                    messages.append({"role": "system", "content": input})
                else:
                    messages.extend(input)
            final_model_kwargs["messages"] = messages
        else:
            raise ValueError(f"model_type {model_type} is not supported")

        # Ensure model is specified
        if "model" not in final_model_kwargs:
            raise ValueError("model must be specified")

        return final_model_kwargs

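    # Example of the conversion above (hypothetical deployment names):
    #   convert_inputs_to_api_kwargs("hello", {"model": "gpt-4"}, ModelType.LLM)
    #       -> {"model": "gpt-4", "messages": [{"role": "system", "content": "hello"}]}
    #   convert_inputs_to_api_kwargs("hello", {"model": "text-embedding-3-small"}, ModelType.EMBEDDER)
    #       -> {"model": "text-embedding-3-small", "input": ["hello"]}
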
    def parse_chat_completion(self, completion: ChatCompletion) -> GeneratorOutput:
        """Parse the chat completion response."""
        log.debug(f"completion: {completion}")
        try:
            data = self.chat_completion_parser(completion)
            usage = self.track_completion_usage(completion)
            # By adalflow convention, the parsed text goes to `raw_response`;
            # `data` is reserved for downstream output processors.
            return GeneratorOutput(data=None, usage=usage, raw_response=data)
        except Exception as e:
            log.error(f"Error parsing completion: {e}")
            return GeneratorOutput(
                data=None, error=str(e), raw_response=str(completion)
            )

    def track_completion_usage(self, completion: ChatCompletion) -> CompletionUsage:
        """Track completion token usage."""
        usage = completion.usage
        return CompletionUsage(
            completion_tokens=usage.completion_tokens,
            prompt_tokens=usage.prompt_tokens,
            total_tokens=usage.total_tokens,
        )

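    # Shape of the parsed output (illustrative values only):
    #   out = client.parse_chat_completion(completion)
    #   out.raw_response -> "Paris"   # parsed first-message content
    #   out.usage        -> CompletionUsage(completion_tokens=1, prompt_tokens=12, total_tokens=13)
    #   out.error        -> None unless parsing failed
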
    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
            BadRequestError,
        ),
        max_time=5,
    )
    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
        """Make a synchronous call to the Azure OpenAI API."""
        log.info(f"api_kwargs: {api_kwargs}")
        if model_type == ModelType.EMBEDDER:
            return self.sync_client.embeddings.create(**api_kwargs)
        elif model_type == ModelType.LLM:
            if api_kwargs.get("stream", False):
                log.debug("streaming call")
                self.chat_completion_parser = handle_streaming_response
            return self.sync_client.chat.completions.create(**api_kwargs)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
            BadRequestError,
        ),
        max_time=5,
    )
    async def acall(
        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
    ):
        """Make an asynchronous call to the Azure OpenAI API."""
        if self.async_client is None:
            self.async_client = self.init_async_client()
        if model_type == ModelType.EMBEDDER:
            return await self.async_client.embeddings.create(**api_kwargs)
        elif model_type == ModelType.LLM:
            return await self.async_client.chat.completions.create(**api_kwargs)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

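    # Hypothetical async usage sketch (same environment variables as `call`;
    # the deployment name "gpt-4o" is an assumption):
    #   import asyncio
    #
    #   async def main():
    #       client = AzureClient()
    #       kwargs = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}
    #       completion = await client.acall(api_kwargs=kwargs, model_type=ModelType.LLM)
    #       print(client.parse_chat_completion(completion).raw_response)
    #
    #   asyncio.run(main())
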
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureClient":
        """Create an instance from a dictionary."""
        obj = super().from_dict(data)
        obj.sync_client = obj.init_sync_client()
        obj.async_client = obj.init_async_client()
        return obj

    def to_dict(self) -> Dict[str, Any]:
        """Convert the instance to a dictionary."""
        exclude = ["sync_client", "async_client"]
        output = super().to_dict(exclude=exclude)
        return output
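
if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the client. It assumes
    # AZURE_OPENAI_API_KEY, AZURE_OPENAI_VERSION, and AZURE_OPENAI_ENDPOINT
    # are set, and that a chat model is deployed under the hypothetical
    # name "gpt-4o"; adjust to your deployment.
    client = AzureClient()
    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="What is the capital of France?",
        model_kwargs={"model": "gpt-4o"},
        model_type=ModelType.LLM,
    )
    completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
    output = client.parse_chat_completion(completion)
    print(output.raw_response)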