Skip to content

Commit 1465570

Browse files
authored
Merge pull request #8 from pamelafox/typed
Add py.typed stub
2 parents 18a7db3 + 3a320eb commit 1465570

File tree

8 files changed

+64
-18
lines changed

8 files changed

+64
-18
lines changed

.github/workflows/python.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@ jobs:
3131
- name: Run unit tests
3232
run: |
3333
python3 -m pytest -s -vv --cov --cov-fail-under=99
34+
- name: Run type checks
35+
run: mypy .

.vscode/settings.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,12 @@
33
"tests"
44
],
55
"python.testing.unittestEnabled": false,
6-
"python.testing.pytestEnabled": true
6+
"python.testing.pytestEnabled": true,
7+
"files.exclude": {
8+
".coverage": true,
9+
".pytest_cache": true,
10+
"__pycache__": true,
11+
".ruff_cache": true,
12+
".mypy_cache": true,
13+
}
714
}

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## [0.1.1] - May 2, 2024
6+
7+
- Add `py.typed` file so that mypy can find the type hints in this package.
8+
59
## [0.1.0] - May 2, 2024
610

711
- Add `count_tokens_for_system_and_tools` to count tokens for system message and tools. You should count the tokens for both together, since the token count for tools varies based off whether a system message is provided.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "openai-messages-token-helper"
33
description = "A helper library for estimating tokens used by messages sent through OpenAI Chat Completions API."
4-
version = "0.1.0"
4+
version = "0.1.1"
55
authors = [{name = "Pamela Fox"}]
66
requires-python = ">=3.9"
77
readme = "README.md"
@@ -33,7 +33,8 @@ dev = [
3333
"black",
3434
"flit",
3535
"azure-identity",
36-
"python-dotenv"
36+
"python-dotenv",
37+
"mypy"
3738
]
3839

3940
[build-system]

src/openai_messages_token_helper/model_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def count_tokens_for_message(model: str, message: Mapping[str, object], default_
113113

114114
def count_tokens_for_system_and_tools(
115115
model: str,
116-
system_message: dict | None = None,
116+
system_message: Mapping[str, object] | None = None,
117117
tools: list[dict[str, dict]] | None = None,
118118
tool_choice: str | dict | None = None,
119119
default_to_cl100k: bool = False,

src/py.typed

Whitespace-only changes.

tests/verify_functions.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,47 @@
11
import os
2+
from typing import Union
23

34
import azure.identity
45
import openai
56
from dotenv import load_dotenv
6-
from functions import FUNCTION_COUNTS
7+
from functions import FUNCTION_COUNTS # type: ignore[import-not-found]
78

89
# Setup the OpenAI client to use either Azure OpenAI or OpenAI API
910
load_dotenv()
1011
API_HOST = os.getenv("API_HOST")
1112

13+
client: Union[openai.OpenAI, openai.AzureOpenAI]
14+
1215
if API_HOST == "azure":
16+
17+
if (azure_openai_version := os.getenv("AZURE_OPENAI_VERSION")) is None:
18+
raise ValueError("Missing Azure OpenAI version")
19+
if (azure_openai_endpoint := os.getenv("AZURE_OPENAI_ENDPOINT")) is None:
20+
raise ValueError("Missing Azure OpenAI endpoint")
21+
if (azure_openai_deployment := os.getenv("AZURE_OPENAI_DEPLOYMENT")) is None:
22+
raise ValueError("Missing Azure OpenAI deployment")
23+
1324
token_provider = azure.identity.get_bearer_token_provider(
1425
azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
1526
)
1627
client = openai.AzureOpenAI(
17-
api_version=os.getenv("AZURE_OPENAI_VERSION"),
18-
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
28+
api_version=azure_openai_version,
29+
azure_endpoint=azure_openai_endpoint,
1930
azure_ad_token_provider=token_provider,
2031
)
21-
MODEL_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT")
32+
MODEL_NAME = azure_openai_deployment
2233
else:
23-
client = openai.OpenAI(api_key=os.getenv("OPENAI_KEY"))
24-
MODEL_NAME = os.getenv("OPENAI_MODEL")
34+
if (openai_key := os.getenv("OPENAI_KEY")) is None:
35+
raise ValueError("Missing OpenAI API key")
36+
if (openai_model := os.getenv("OPENAI_MODEL")) is None:
37+
raise ValueError("Missing OpenAI model")
38+
client = openai.OpenAI(api_key=openai_key)
39+
MODEL_NAME = openai_model
40+
2541

2642
# Test the token count for each message
2743
for function_count_pair in FUNCTION_COUNTS:
28-
response = client.chat.completions.create(
44+
response = client.chat.completions.create( # type: ignore[call-overload]
2945
model=MODEL_NAME,
3046
temperature=0.7,
3147
n=1,
@@ -35,6 +51,7 @@
3551
)
3652

3753
print(function_count_pair["tools"])
54+
assert response.usage is not None, "Expected usage to be present"
3855
assert (
3956
response.usage.prompt_tokens == function_count_pair["count"]
4057
), f"Expected {function_count_pair['count']} tokens, got {response.usage.prompt_tokens}"

tests/verify_openai.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,54 @@
11
import os
2+
from typing import Union
23

34
import azure.identity
45
import openai
56
from dotenv import load_dotenv
6-
from messages import MESSAGE_COUNTS
7+
from messages import MESSAGE_COUNTS # type: ignore[import-not-found]
78

89
# Setup the OpenAI client to use either Azure OpenAI or OpenAI API
910
load_dotenv()
1011
API_HOST = os.getenv("API_HOST")
1112

13+
client: Union[openai.OpenAI, openai.AzureOpenAI]
14+
1215
if API_HOST == "azure":
16+
if (azure_openai_version := os.getenv("AZURE_OPENAI_VERSION")) is None:
17+
raise ValueError("Missing Azure OpenAI version")
18+
if (azure_openai_endpoint := os.getenv("AZURE_OPENAI_ENDPOINT")) is None:
19+
raise ValueError("Missing Azure OpenAI endpoint")
20+
if (azure_openai_deployment := os.getenv("AZURE_OPENAI_DEPLOYMENT")) is None:
21+
raise ValueError("Missing Azure OpenAI deployment")
22+
1323
token_provider = azure.identity.get_bearer_token_provider(
1424
azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
1525
)
1626
client = openai.AzureOpenAI(
17-
api_version=os.getenv("AZURE_OPENAI_VERSION"),
18-
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
27+
api_version=azure_openai_version,
28+
azure_endpoint=azure_openai_endpoint,
1929
azure_ad_token_provider=token_provider,
2030
)
21-
MODEL_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT")
31+
MODEL_NAME = azure_openai_deployment
2232
else:
23-
client = openai.OpenAI(api_key=os.getenv("OPENAI_KEY"))
24-
MODEL_NAME = os.getenv("OPENAI_MODEL")
33+
if (openai_key := os.getenv("OPENAI_KEY")) is None:
34+
raise ValueError("Missing OpenAI API key")
35+
if (openai_model := os.getenv("OPENAI_MODEL")) is None:
36+
raise ValueError("Missing OpenAI model")
37+
client = openai.OpenAI(api_key=openai_key)
38+
MODEL_NAME = openai_model
2539

2640
# Test the token count for each message
2741
for message_count_pair in MESSAGE_COUNTS:
2842
response = client.chat.completions.create(
2943
model=MODEL_NAME,
3044
temperature=0.7,
3145
n=1,
32-
messages=[message_count_pair["message"]],
46+
messages=[message_count_pair["message"]], # type: ignore[list-item]
3347
)
3448

3549
print(message_count_pair["message"])
3650
expected_tokens = message_count_pair["count"]
51+
assert response.usage is not None, "Expected usage to be present"
3752
assert (
3853
response.usage.prompt_tokens == expected_tokens
3954
), f"Expected {expected_tokens} tokens, got {response.usage.prompt_tokens}"

0 commit comments

Comments
 (0)