|
3 | 3 | import azure.identity
|
4 | 4 | import openai
|
5 | 5 | from dotenv import load_dotenv
|
6 |
| -from messages import MESSAGE_COUNTS |
| 6 | +from messages import MESSAGE_COUNTS # type: ignore[import-not-found] |
7 | 7 |
|
# Configure the chat client from environment variables: Azure OpenAI
# (keyless auth via Entra ID) when API_HOST is "azure", otherwise the
# public OpenAI API with an API key.
load_dotenv()
API_HOST = os.getenv("API_HOST")

client: openai.OpenAI | openai.AzureOpenAI

if API_HOST == "azure":
    # Fail fast with a clear message for each required Azure setting.
    azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")
    if azure_openai_version is None:
        raise ValueError("Missing Azure OpenAI version")
    azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    if azure_openai_endpoint is None:
        raise ValueError("Missing Azure OpenAI endpoint")
    azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
    if azure_openai_deployment is None:
        raise ValueError("Missing Azure OpenAI deployment")

    # Keyless authentication: bearer tokens are minted on demand from the
    # ambient Azure credential (CLI login, managed identity, etc.).
    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    client = openai.AzureOpenAI(
        azure_endpoint=azure_openai_endpoint,
        api_version=azure_openai_version,
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = azure_openai_deployment
else:
    openai_key = os.getenv("OPENAI_KEY")
    if openai_key is None:
        raise ValueError("Missing OpenAI API key")
    openai_model = os.getenv("OPENAI_MODEL")
    if openai_model is None:
        raise ValueError("Missing OpenAI model")
    client = openai.OpenAI(api_key=openai_key)
    MODEL_NAME = openai_model
25 | 38 |
|
# Verify that the API-reported prompt token usage matches the expected
# count for every (message, count) pair declared in messages.py.
for pair in MESSAGE_COUNTS:
    message = pair["message"]
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0.7,
        n=1,
        messages=[message],  # type: ignore[list-item]
    )

    print(message)
    expected_tokens = pair["count"]
    usage = completion.usage
    assert usage is not None, "Expected usage to be present"
    assert usage.prompt_tokens == expected_tokens, (
        f"Expected {expected_tokens} tokens, got {usage.prompt_tokens}"
    )
|
0 commit comments