
Commit f381e91

uncommit unrelated files
1 parent 98177d9 commit f381e91

7 files changed: +444 -139 lines changed

adalflow/adalflow/components/model_client/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -65,6 +65,12 @@
     OptionalPackages.OPENAI,
 )
 
+# Azure OpenAI Client
+AzureAIClient = LazyImport(
+    "adalflow.components.model_client.azureai_client.AzureAIClient",
+    OptionalPackages.AZURE,
+)
+
 __all__ = [
     "CohereAPIClient",
     "TransformerReranker",
@@ -76,6 +82,7 @@
     "GroqAPIClient",
     "OpenAIClient",
     "GoogleGenAIClient",
+    "AzureAIClient",
 ]
 
 for name in __all__:
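With this registration, `AzureAIClient` becomes importable from `adalflow.components.model_client`, and the Azure dependencies are only loaded when the class is first used. A minimal usage sketch, not part of this commit: it assumes `adalflow` is installed with its Azure extras and that the environment variables read by `init_sync_client` (and the version variable mentioned in the old docstring) are set.

```python
# Illustrative sketch only. Assumes adalflow is installed with its Azure extras
# and that the Azure OpenAI environment variables below are set to real values.
import os

# Resolved lazily via LazyImport; the Azure/OpenAI packages load on first use.
from adalflow.components.model_client import AzureAIClient

os.environ.setdefault("AZURE_OPENAI_API_KEY", "your_api_key")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://your-endpoint.openai.azure.com/")
os.environ.setdefault("AZURE_OPENAI_VERSION", "your_api_version")

client = AzureAIClient()  # credentials are picked up from the environment
```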

adalflow/adalflow/components/model_client/azureai_client.py

Lines changed: 30 additions & 78 deletions
@@ -129,67 +129,12 @@ class AzureAIClient(ModelClient):
     authentication. It is recommended to set environment variables for sensitive data like API keys.
 
     Args:
-        api_key (Optional[str]): Azure OpenAI API key. Default is None.
-        api_version (Optional[str]): API version to use. Default is None.
-        azure_endpoint (Optional[str]): Azure OpenAI endpoint URL. Default is None.
-        credential (Optional[DefaultAzureCredential]): Azure AD credential for token-based authentication. Default is None.
-        chat_completion_parser (Callable[[Completion], Any]): Function to parse chat completions. Default is `get_first_message_content`.
-        input_type (Literal["text", "messages"]): Format for input, either "text" or "messages". Default is "text".
-
-    **Setup Instructions:**
-
-    - **Using API Key:**
-      Set up the following environment variables:
-      ```bash
-      export AZURE_OPENAI_API_KEY="your_api_key"
-      export AZURE_OPENAI_ENDPOINT="your_endpoint"
-      export AZURE_OPENAI_VERSION="your_version"
-      ```
-
-    - **Using Azure AD Token:**
-      Ensure you have configured Azure AD credentials. The `DefaultAzureCredential` will automatically use your configured credentials.
-
-    **Example Usage:**
-
-    .. code-block:: python
-
-        from azure.identity import DefaultAzureCredential
-        from your_module import AzureAIClient  # Adjust import based on your module name
-
-        # Initialize with API key
-        client = AzureAIClient(
-            api_key="your_api_key",
-            api_version="2023-05-15",
-            azure_endpoint="https://your-endpoint.openai.azure.com/"
-        )
-
-        # Or initialize with Azure AD token
-        client = AzureAIClient(
-            api_version="2023-05-15",
-            azure_endpoint="https://your-endpoint.openai.azure.com/",
-            credential=DefaultAzureCredential()
-        )
-
-        # Example call to the chat completion API
-        api_kwargs = {
-            "model": "gpt-3.5-turbo",
-            "messages": [{"role": "user", "content": "What is the meaning of life?"}],
-            "stream": True
-        }
-        response = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
-
-        for chunk in response:
-            print(chunk)
-
-
-    **Notes:**
-    - Ensure that the API key or credentials are correctly set up and accessible to avoid authentication errors.
-    - Use `chat_completion_parser` to define how to extract and handle the chat completion responses.
-    - The `input_type` parameter determines how input is formatted for the API call.
-
-    **References:**
-    - [Azure OpenAI API Documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
-    - [OpenAI API Documentation](https://platform.openai.com/docs/guides/text-generation)
+        api_key: Azure OpenAI API key.
+        api_version: Azure OpenAI API version.
+        azure_endpoint: Azure OpenAI endpoint.
+        credential: Azure AD credential for token-based authentication.
+        chat_completion_parser: Function to parse chat completions.
+        input_type: Input format, either "text" or "messages".
     """
 
     def __init__(
@@ -201,22 +146,11 @@ def __init__(
         chat_completion_parser: Callable[[Completion], Any] = None,
         input_type: Literal["text", "messages"] = "text",
     ):
-        r"""It is recommended to set the API_KEY into the environment variable instead of passing it as an argument.
-
-
-        Initializes the Azure OpenAI client with either API key or AAD token authentication.
-
-        Args:
-            api_key: Azure OpenAI API key.
-            api_version: Azure OpenAI API version.
-            azure_endpoint: Azure OpenAI endpoint.
-            credential: Azure AD credential for token-based authentication.
-            chat_completion_parser: Function to parse chat completions.
-            input_type: Input format, either "text" or "messages".
-
-        """
         super().__init__()
 
+        # Model type will be set dynamically based on the operation
+        self._model_type = None
+
         # added api_type azure for azure Ai
         self.api_type = "azure"
         self._api_key = api_key
@@ -230,6 +164,16 @@ def __init__(
         )
         self._input_type = input_type
 
+    @property
+    def model_type(self) -> ModelType:
+        """Get the current model type. Defaults to LLM if not set."""
+        return self._model_type or ModelType.LLM
+
+    @model_type.setter
+    def model_type(self, value: ModelType):
+        """Set the model type."""
+        self._model_type = value
+
     def init_sync_client(self):
         api_key = self._api_key or os.getenv("AZURE_OPENAI_API_KEY")
         azure_endpoint = self._azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
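The new `model_type` property gives the client a sticky notion of the most recent operation: it falls back to `ModelType.LLM` until the setter records something else. A small behaviour sketch, not part of the commit; it assumes the client can be constructed as in the earlier sketch and that `ModelType` lives at the import path shown (an assumption).

```python
# Behaviour sketch for the model_type property added above; illustrative only.
from adalflow.core.types import ModelType  # assumed import path for ModelType
from adalflow.components.model_client import AzureAIClient

client = AzureAIClient()

# Before any call, _model_type is None, so the property falls back to LLM.
assert client.model_type == ModelType.LLM

# The setter (used by call()) records the most recent operation type.
client.model_type = ModelType.EMBEDDER
assert client.model_type == ModelType.EMBEDDER
```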
@@ -357,6 +301,10 @@ def convert_inputs_to_api_kwargs(
         """
 
         final_model_kwargs = model_kwargs.copy()
+        # If model_type is UNDEFINED, use the current model_type property
+        if model_type == ModelType.UNDEFINED:
+            model_type = self.model_type
+
         if model_type == ModelType.EMBEDDER:
             if isinstance(input, str):
                 input = [input]
@@ -383,14 +331,13 @@ def convert_inputs_to_api_kwargs(
                 if match:
                     system_prompt = match.group(1)
                     input_str = match.group(2)
-
                 else:
                     print("No match found.")
                 if system_prompt and input_str:
                     messages.append({"role": "system", "content": system_prompt})
                     messages.append({"role": "user", "content": input_str})
             if len(messages) == 0:
-                messages.append({"role": "system", "content": input})
+                messages.append({"role": "user", "content": input})
             final_model_kwargs["messages"] = messages
         else:
             raise ValueError(f"model_type {model_type} is not supported")
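The one-line change above means a plain prompt that does not contain the system/user prompt tags is now sent as a "user" message rather than a "system" message. A sketch of the resulting kwargs, illustrative only: it reuses the assumptions of the earlier sketches (constructible client, assumed `ModelType` import path), and the keyword names and deployment name are placeholders.

```python
# Sketch of the message construction after this change; not part of the commit.
from adalflow.core.types import ModelType  # assumed import path
from adalflow.components.model_client import AzureAIClient

client = AzureAIClient()
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="What is the meaning of life?",            # plain prompt, no system/user split
    model_kwargs={"model": "your-deployment-name"},  # placeholder deployment name
    model_type=ModelType.LLM,
)
# Previously the untagged prompt became a system turn; now it is a user turn:
# api_kwargs["messages"] == [{"role": "user", "content": "What is the meaning of life?"}]
```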
@@ -409,8 +356,13 @@ def convert_inputs_to_api_kwargs(
         )
     def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
         """
-        kwargs is the combined input and model_kwargs. Support streaming call.
+        kwargs is the combined input and model_kwargs. Support streaming call.
+        Also updates the internal model_type based on the operation.
         """
+        # Update internal model type based on the operation
+        if model_type != ModelType.UNDEFINED:
+            self.model_type = model_type
+
         log.info(f"api_kwargs: {api_kwargs}")
         if model_type == ModelType.EMBEDDER:
             return self.sync_client.embeddings.create(**api_kwargs)
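Taken together, `call()` now records the last explicit `ModelType`, and a later `convert_inputs_to_api_kwargs` invoked with `ModelType.UNDEFINED` falls back to that recorded value instead of hitting the `else: raise ValueError(...)` branch. An end-to-end sketch under the same assumptions as the earlier ones (assumed `ModelType` import path, placeholder deployment name, valid Azure credentials in the environment):

```python
# End-to-end sketch of the dynamic model_type tracking; illustrative only.
from adalflow.core.types import ModelType  # assumed import path
from adalflow.components.model_client import AzureAIClient

client = AzureAIClient()

# An explicit LLM call records ModelType.LLM on the client via the setter.
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="Summarize LazyImport in one sentence.",
    model_kwargs={"model": "your-deployment-name", "stream": False},  # placeholder
    model_type=ModelType.LLM,
)
response = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)

# A follow-up conversion that leaves model_type as UNDEFINED reuses the recorded
# LLM type inside convert_inputs_to_api_kwargs instead of raising ValueError.
api_kwargs_2 = client.convert_inputs_to_api_kwargs(
    input="And one more sentence, please.",
    model_kwargs={"model": "your-deployment-name"},  # placeholder
)
```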
