Skip to content

Commit de25487

Browse files
committed
Add py.typed stub
1 parent 18a7db3 commit de25487

File tree

8 files changed

+62
-18
lines changed

8 files changed

+62
-18
lines changed

.github/workflows/python.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@ jobs:
3131
- name: Run unit tests
3232
run: |
3333
python3 -m pytest -s -vv --cov --cov-fail-under=99
34+
- name: Run type checks
35+
run: mypy .

.vscode/settings.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,12 @@
33
"tests"
44
],
55
"python.testing.unittestEnabled": false,
6-
"python.testing.pytestEnabled": true
6+
"python.testing.pytestEnabled": true,
7+
"files.exclude": {
8+
".coverage": true,
9+
".pytest_cache": true,
10+
"__pycache__": true,
11+
".ruff_cache": true,
12+
".mypy_cache": true,
13+
}
714
}

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## [0.1.1] - May 2, 2024
6+
7+
- Add `py.typed` file so that mypy can find the type hints in this package.
8+
59
## [0.1.0] - May 2, 2024
610

711
- Add `count_tokens_for_system_and_tools` to count tokens for system message and tools. You should count the tokens for both together, since the token count for tools varies based off whether a system message is provided.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "openai-messages-token-helper"
33
description = "A helper library for estimating tokens used by messages sent through OpenAI Chat Completions API."
4-
version = "0.1.0"
4+
version = "0.1.1"
55
authors = [{name = "Pamela Fox"}]
66
requires-python = ">=3.9"
77
readme = "README.md"
@@ -33,7 +33,8 @@ dev = [
3333
"black",
3434
"flit",
3535
"azure-identity",
36-
"python-dotenv"
36+
"python-dotenv",
37+
"mypy"
3738
]
3839

3940
[build-system]

src/openai_messages_token_helper/model_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def count_tokens_for_message(model: str, message: Mapping[str, object], default_
113113

114114
def count_tokens_for_system_and_tools(
115115
model: str,
116-
system_message: dict | None = None,
116+
system_message: Mapping[str, object] | None = None,
117117
tools: list[dict[str, dict]] | None = None,
118118
tool_choice: str | dict | None = None,
119119
default_to_cl100k: bool = False,

src/py.typed

Whitespace-only changes.

tests/verify_functions.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,44 @@
33
import azure.identity
44
import openai
55
from dotenv import load_dotenv
6-
from functions import FUNCTION_COUNTS
6+
from functions import FUNCTION_COUNTS # type: ignore[import-not-found]
77

88
# Setup the OpenAI client to use either Azure OpenAI or OpenAI API
99
load_dotenv()
1010
API_HOST = os.getenv("API_HOST")
1111

12+
client: openai.OpenAI | openai.AzureOpenAI
13+
1214
if API_HOST == "azure":
15+
16+
if (azure_openai_version := os.getenv("AZURE_OPENAI_VERSION")) is None:
17+
raise ValueError("Missing Azure OpenAI version")
18+
if (azure_openai_endpoint := os.getenv("AZURE_OPENAI_ENDPOINT")) is None:
19+
raise ValueError("Missing Azure OpenAI endpoint")
20+
if (azure_openai_deployment := os.getenv("AZURE_OPENAI_DEPLOYMENT")) is None:
21+
raise ValueError("Missing Azure OpenAI deployment")
22+
1323
token_provider = azure.identity.get_bearer_token_provider(
1424
azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
1525
)
1626
client = openai.AzureOpenAI(
17-
api_version=os.getenv("AZURE_OPENAI_VERSION"),
18-
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
27+
api_version=azure_openai_version,
28+
azure_endpoint=azure_openai_endpoint,
1929
azure_ad_token_provider=token_provider,
2030
)
21-
MODEL_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT")
31+
MODEL_NAME = azure_openai_deployment
2232
else:
23-
client = openai.OpenAI(api_key=os.getenv("OPENAI_KEY"))
24-
MODEL_NAME = os.getenv("OPENAI_MODEL")
33+
if (openai_key := os.getenv("OPENAI_KEY")) is None:
34+
raise ValueError("Missing OpenAI API key")
35+
if (openai_model := os.getenv("OPENAI_MODEL")) is None:
36+
raise ValueError("Missing OpenAI model")
37+
client = openai.OpenAI(api_key=openai_key)
38+
MODEL_NAME = openai_model
39+
2540

2641
# Test the token count for each message
2742
for function_count_pair in FUNCTION_COUNTS:
28-
response = client.chat.completions.create(
43+
response = client.chat.completions.create( # type: ignore[call-overload]
2944
model=MODEL_NAME,
3045
temperature=0.7,
3146
n=1,
@@ -35,6 +50,7 @@
3550
)
3651

3752
print(function_count_pair["tools"])
53+
assert response.usage is not None, "Expected usage to be present"
3854
assert (
3955
response.usage.prompt_tokens == function_count_pair["count"]
4056
), f"Expected {function_count_pair['count']} tokens, got {response.usage.prompt_tokens}"

tests/verify_openai.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,51 @@
33
import azure.identity
44
import openai
55
from dotenv import load_dotenv
6-
from messages import MESSAGE_COUNTS
6+
from messages import MESSAGE_COUNTS # type: ignore[import-not-found]
77

88
# Setup the OpenAI client to use either Azure OpenAI or OpenAI API
99
load_dotenv()
1010
API_HOST = os.getenv("API_HOST")
1111

12+
client: openai.OpenAI | openai.AzureOpenAI
13+
1214
if API_HOST == "azure":
15+
if (azure_openai_version := os.getenv("AZURE_OPENAI_VERSION")) is None:
16+
raise ValueError("Missing Azure OpenAI version")
17+
if (azure_openai_endpoint := os.getenv("AZURE_OPENAI_ENDPOINT")) is None:
18+
raise ValueError("Missing Azure OpenAI endpoint")
19+
if (azure_openai_deployment := os.getenv("AZURE_OPENAI_DEPLOYMENT")) is None:
20+
raise ValueError("Missing Azure OpenAI deployment")
21+
1322
token_provider = azure.identity.get_bearer_token_provider(
1423
azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
1524
)
1625
client = openai.AzureOpenAI(
17-
api_version=os.getenv("AZURE_OPENAI_VERSION"),
18-
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
26+
api_version=azure_openai_version,
27+
azure_endpoint=azure_openai_endpoint,
1928
azure_ad_token_provider=token_provider,
2029
)
21-
MODEL_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT")
30+
MODEL_NAME = azure_openai_deployment
2231
else:
23-
client = openai.OpenAI(api_key=os.getenv("OPENAI_KEY"))
24-
MODEL_NAME = os.getenv("OPENAI_MODEL")
32+
if (openai_key := os.getenv("OPENAI_KEY")) is None:
33+
raise ValueError("Missing OpenAI API key")
34+
if (openai_model := os.getenv("OPENAI_MODEL")) is None:
35+
raise ValueError("Missing OpenAI model")
36+
client = openai.OpenAI(api_key=openai_key)
37+
MODEL_NAME = openai_model
2538

2639
# Test the token count for each message
2740
for message_count_pair in MESSAGE_COUNTS:
2841
response = client.chat.completions.create(
2942
model=MODEL_NAME,
3043
temperature=0.7,
3144
n=1,
32-
messages=[message_count_pair["message"]],
45+
messages=[message_count_pair["message"]], # type: ignore[list-item]
3346
)
3447

3548
print(message_count_pair["message"])
3649
expected_tokens = message_count_pair["count"]
50+
assert response.usage is not None, "Expected usage to be present"
3751
assert (
3852
response.usage.prompt_tokens == expected_tokens
3953
), f"Expected {expected_tokens} tokens, got {response.usage.prompt_tokens}"

0 commit comments

Comments
 (0)