@@ -129,67 +129,12 @@ class AzureAIClient(ModelClient):
129129 authentication. It is recommended to set environment variables for sensitive data like API keys.
130130
131131 Args:
132- api_key (Optional[str]): Azure OpenAI API key. Default is None.
133- api_version (Optional[str]): API version to use. Default is None.
134- azure_endpoint (Optional[str]): Azure OpenAI endpoint URL. Default is None.
135- credential (Optional[DefaultAzureCredential]): Azure AD credential for token-based authentication. Default is None.
136- chat_completion_parser (Callable[[Completion], Any]): Function to parse chat completions. Default is `get_first_message_content`.
137- input_type (Literal["text", "messages"]): Format for input, either "text" or "messages". Default is "text".
138-
139- **Setup Instructions:**
140-
141- - **Using API Key:**
142- Set up the following environment variables:
143- ```bash
144- export AZURE_OPENAI_API_KEY="your_api_key"
145- export AZURE_OPENAI_ENDPOINT="your_endpoint"
146- export AZURE_OPENAI_VERSION="your_version"
147- ```
148-
149- - **Using Azure AD Token:**
150- Ensure you have configured Azure AD credentials. The `DefaultAzureCredential` will automatically use your configured credentials.
151-
152- **Example Usage:**
153-
154- .. code-block:: python
155-
156- from azure.identity import DefaultAzureCredential
157- from your_module import AzureAIClient # Adjust import based on your module name
158-
159- # Initialize with API key
160- client = AzureAIClient(
161- api_key="your_api_key",
162- api_version="2023-05-15",
163- azure_endpoint="https://your-endpoint.openai.azure.com/"
164- )
165-
166- # Or initialize with Azure AD token
167- client = AzureAIClient(
168- api_version="2023-05-15",
169- azure_endpoint="https://your-endpoint.openai.azure.com/",
170- credential=DefaultAzureCredential()
171- )
172-
173- # Example call to the chat completion API
174- api_kwargs = {
175- "model": "gpt-3.5-turbo",
176- "messages": [{"role": "user", "content": "What is the meaning of life?"}],
177- "stream": True
178- }
179- response = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
180-
181- for chunk in response:
182- print(chunk)
183-
184-
185- **Notes:**
186- - Ensure that the API key or credentials are correctly set up and accessible to avoid authentication errors.
187- - Use `chat_completion_parser` to define how to extract and handle the chat completion responses.
188- - The `input_type` parameter determines how input is formatted for the API call.
189-
190- **References:**
191- - [Azure OpenAI API Documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
192- - [OpenAI API Documentation](https://platform.openai.com/docs/guides/text-generation)
132+ api_key: Azure OpenAI API key.
133+ api_version: Azure OpenAI API version.
134+ azure_endpoint: Azure OpenAI endpoint.
135+ credential: Azure AD credential for token-based authentication.
136+ chat_completion_parser: Function to parse chat completions.
137+ input_type: Input format, either "text" or "messages".
193138 """
194139
195140 def __init__ (
@@ -201,22 +146,11 @@ def __init__(
201146 chat_completion_parser : Callable [[Completion ], Any ] = None ,
202147 input_type : Literal ["text" , "messages" ] = "text" ,
203148 ):
204- r"""It is recommended to set the API_KEY into the environment variable instead of passing it as an argument.
205-
206-
207- Initializes the Azure OpenAI client with either API key or AAD token authentication.
208-
209- Args:
210- api_key: Azure OpenAI API key.
211- api_version: Azure OpenAI API version.
212- azure_endpoint: Azure OpenAI endpoint.
213- credential: Azure AD credential for token-based authentication.
214- chat_completion_parser: Function to parse chat completions.
215- input_type: Input format, either "text" or "messages".
216-
217- """
218149 super ().__init__ ()
219150
151+ # Model type will be set dynamically based on the operation
152+ self ._model_type = None
153+
220154 # added api_type azure for azure Ai
221155 self .api_type = "azure"
222156 self ._api_key = api_key
@@ -230,6 +164,16 @@ def __init__(
230164 )
231165 self ._input_type = input_type
232166
167+ @property
168+ def model_type (self ) -> ModelType :
169+ """Get the current model type. Defaults to LLM if not set."""
170+ return self ._model_type or ModelType .LLM
171+
172+ @model_type .setter
173+ def model_type (self , value : ModelType ):
174+ """Set the model type."""
175+ self ._model_type = value
176+
233177 def init_sync_client (self ):
234178 api_key = self ._api_key or os .getenv ("AZURE_OPENAI_API_KEY" )
235179 azure_endpoint = self ._azure_endpoint or os .getenv ("AZURE_OPENAI_ENDPOINT" )
@@ -357,6 +301,10 @@ def convert_inputs_to_api_kwargs(
357301 """
358302
359303 final_model_kwargs = model_kwargs .copy ()
304+ # If model_type is UNDEFINED, use the current model_type property
305+ if model_type == ModelType .UNDEFINED :
306+ model_type = self .model_type
307+
360308 if model_type == ModelType .EMBEDDER :
361309 if isinstance (input , str ):
362310 input = [input ]
@@ -383,14 +331,13 @@ def convert_inputs_to_api_kwargs(
383331 if match :
384332 system_prompt = match .group (1 )
385333 input_str = match .group (2 )
386-
387334 else :
388335 print ("No match found." )
389336 if system_prompt and input_str :
390337 messages .append ({"role" : "system" , "content" : system_prompt })
391338 messages .append ({"role" : "user" , "content" : input_str })
392339 if len (messages ) == 0 :
393- messages .append ({"role" : "system " , "content" : input })
340+ messages .append ({"role" : "user " , "content" : input })
394341 final_model_kwargs ["messages" ] = messages
395342 else :
396343 raise ValueError (f"model_type { model_type } is not supported" )
@@ -409,8 +356,13 @@ def convert_inputs_to_api_kwargs(
409356 )
410357 def call (self , api_kwargs : Dict = {}, model_type : ModelType = ModelType .UNDEFINED ):
411358 """
412- kwargs is the combined input and model_kwargs. Support streaming call.
359+ kwargs is the combined input and model_kwargs. Support streaming call.
360+ Also updates the internal model_type based on the operation.
413361 """
362+ # Update internal model type based on the operation
363+ if model_type != ModelType .UNDEFINED :
364+ self .model_type = model_type
365+
414366 log .info (f"api_kwargs: { api_kwargs } " )
415367 if model_type == ModelType .EMBEDDER :
416368 return self .sync_client .embeddings .create (** api_kwargs )
0 commit comments