Commit 06e1f87

[8.0.2] Usage + typo in versioning
1 parent cdf0d61 commit 06e1f87

12 files changed: +312 −93 lines changed


docs/changelog.rst

Lines changed: 10 additions & 0 deletions
@@ -7,6 +7,16 @@ minor versions.
 
 All relevant steps to be taken will be mentioned here.
 
+8.0.2
+-----
+
+- Added usage tracking for OpenAI and Anthropic
+
+8.0.1
+-----
+
+- Fixed versioning typo; we are now in the 8.x.x series
+- Fix bug in structured generation for ``Openai``.
 
 0.8.0
 -----
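
As a quick illustration of the new usage tracking from the caller's side, here is a minimal sketch, assuming the model class is exported as ``Anthropic`` and that ``tt.human`` builds a user message (neither name appears in this diff):

```python
# Minimal sketch of the 8.0.2 usage tracking (assumed names noted above).
# With usage=True, chat() returns a (output, tt.Usage) tuple, per the diff
# to tuneapi/apis/model_anthropic.py below.
import tuneapi.types as tt
from tuneapi.apis import Anthropic  # assumed export path

model = Anthropic()  # falls back to the ANTHROPIC_TOKEN env var
thread = tt.Thread(tt.human("Say hi in five words."))

out, usage = model.chat(thread, usage=True)
print(out)
print(usage.input_tokens, usage.output_tokens, usage.cached_tokens)
```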

docs/conf.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 project = "tuneapi"
 copyright = "2024-2025, Frello Technologies"
 author = "Frello Technologies"
-release = "8.0.0"
+release = "8.0.2"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

docs/index.rst

Lines changed: 13 additions & 1 deletion
@@ -53,7 +53,7 @@ paste the following code snippet in the prompt to generate the code for LLM API
 
 class MathTest(BaseModel):
     title: str = Field(..., description="Title of the test")
-    problems: List[MathProblem] = Field(..., description="List of math problems")
+    problems: List[MathProblem] = ...  # only list of other BaseModel is allowed
 
 # define a thread which is a collection of messages
 thread = tt.Thread(
@@ -65,6 +65,18 @@ paste the following code snippet in the prompt to generate the code for LLM API
 resp: MathTest = model.chat(thread)
 ```
 
+Structured generation
+---------------------
+
+.. epigraph::
+
+   Types and Logic are the two parts of programming.
+
+
+With structured generation you can get ``pydantic.BaseModel`` objects from the ``tt.ModelInterface.chat`` and
+``tt.ModelInterface.chat_async`` methods. The current limitation is that keys cannot have another ``BaseModel``
+as value; only ``List[BaseModel]`` is allowed.
+
 
 .. toctree::
    :maxdepth: 2
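
Restating the documented limitation as a runnable sketch (``MathProblem`` is an assumed shape, hinted at but not shown in the snippet above):

```python
# Sketch of the structured-generation constraint described above.
from typing import List
from pydantic import BaseModel, Field

class MathProblem(BaseModel):  # assumed shape, not shown in this diff
    question: str = Field(..., description="The problem statement")
    answer: str = Field(..., description="The expected answer")

# Allowed: a key whose value is a List of another BaseModel
class MathTest(BaseModel):
    title: str = Field(..., description="Title of the test")
    problems: List[MathProblem] = ...  # only list of other BaseModel is allowed

# Not allowed, per the limitation above: a key holding a BaseModel directly
# class BadTest(BaseModel):
#     hardest: MathProblem
```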

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tuneapi"
-version = "8.0.0"
+version = "8.0.2"
 description = "Tune AI APIs."
 authors = ["Frello Technology Private Limited <[email protected]>"]
 license = "MIT"

tuneapi/apis/model_anthropic.py

Lines changed: 106 additions & 62 deletions
@@ -6,6 +6,7 @@
 
 import httpx
 import requests
+from copy import deepcopy
 from typing import Optional, Dict, Any, List
 
 import tuneapi.utils as tu
@@ -18,11 +19,12 @@ def __init__(
         self,
         id: Optional[str] = "claude-3-haiku-20240307",
         base_url: str = "https://api.anthropic.com/v1/messages",
+        api_token: Optional[str] = None,
         extra_headers: Optional[Dict[str, str]] = None,
     ):
         self.model_id = id
         self.base_url = base_url
-        self.api_token = tu.ENV.ANTHROPIC_TOKEN("")
+        self.api_token = api_token or tu.ENV.ANTHROPIC_TOKEN("")
         self.extra_headers = extra_headers
 
     def set_api_token(self, token: str) -> None:
@@ -60,13 +62,17 @@ def _process_input(
         prev_tool_id = tu.get_random_string(5)
         for m in thread.chats[int(system != "") :]:
             if m.role == tt.Message.HUMAN:
-                msg = {
-                    "role": "user",
-                    "content": [{"type": "text", "text": m.value.strip()}],
-                }
+                if isinstance(m.value, str):
+                    content = [{"type": "text", "text": m.value}]
+                elif isinstance(m.value, list):
+                    content = deepcopy(m.value)
+                else:
+                    raise Exception(
+                        f"Unknown message type. Got: '{type(m.value)}', expected 'List[Dict[str, Any]]' or 'str'"
+                    )
                 if m.images:
                     for i in m.images:
-                        msg["content"].append(
+                        content.append(
                             {
                                 "type": "image",
                                 "source": {
@@ -76,14 +82,19 @@ def _process_input(
                             },
                         }
                     )
+                msg = {"role": "user", "content": content}
             elif m.role == tt.Message.GPT:
-                msg = {
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": m.value.strip()}],
-                }
+                if isinstance(m.value, str):
+                    content = [{"type": "text", "text": m.value}]
+                elif isinstance(m.value, list):
+                    content = deepcopy(m.value)
+                else:
+                    raise Exception(
+                        f"Unknown message type. Got: '{type(m.value)}', expected 'List[Dict[str, Any]]' or 'str'"
+                    )
                 if m.images:
                     for i in m.images:
-                        msg["content"].append(
+                        content.append(
                             {
                                 "type": "image",
                                 "source": {
@@ -93,6 +104,7 @@ def _process_input(
                             },
                         }
                     )
+                msg = {"role": "assistant", "content": content}
             elif m.role == tt.Message.FUNCTION_CALL:
                 _m = tu.from_json(m.value) if isinstance(m.value, str) else m.value
                 msg = {
@@ -159,49 +171,64 @@ def _process_input(
 
         return headers, data
 
-    def _process_output(self, raw: bool, lines_fn: callable):
+    def _process_output(self, raw: bool, lines_fn: callable, yield_usage: bool):
         fn_call = None
+        usage_dict = {}
         for line in lines_fn():
             if isinstance(line, bytes):
                 line = line.decode().strip()
             if not line or not "data:" in line:
                 continue
 
-            try:
-                # print(line)
-                resp = tu.from_json(line.replace("data:", "").strip())
-                if resp["type"] == "content_block_start":
-                    if resp["content_block"]["type"] == "tool_use":
-                        fn_call = {
-                            "name": resp["content_block"]["name"],
-                            "arguments": "",
-                        }
-                elif resp["type"] == "content_block_delta":
-                    delta = resp["delta"]
-                    delta_type = delta["type"]
-                    if delta_type == "text_delta":
-                        if raw:
-                            yield b"data: " + tu.to_json(
-                                {
-                                    "object": delta_type,
-                                    "choices": [{"delta": {"content": delta["text"]}}],
-                                },
-                                tight=True,
-                            ).encode()
-                            yield b""  # uncomment this line if you want 1:1 with OpenAI
-                        else:
-                            yield delta["text"]
-                    elif delta_type == "input_json_delta":
-                        fn_call["arguments"] += delta["partial_json"]
-                elif resp["type"] == "content_block_stop":
-                    if fn_call:
-                        fn_call["arguments"] = tu.from_json(
-                            fn_call["arguments"] or "{}"
-                        )
-                        yield fn_call
-                        fn_call = None
-            except:
-                break
+            resp = tu.from_json(line.replace("data:", "").strip())
+            if resp["type"] == "message_start":
+                usage = resp["message"]["usage"]
+                usage_dict.update(usage)
+            elif resp["type"] == "content_block_start":
+                if resp["content_block"]["type"] == "tool_use":
+                    fn_call = {
+                        "name": resp["content_block"]["name"],
+                        "arguments": "",
+                    }
+            elif resp["type"] == "content_block_delta":
+                delta = resp["delta"]
+                delta_type = delta["type"]
+                if delta_type == "text_delta":
+                    if raw:
+                        yield b"data: " + tu.to_json(
+                            {
+                                "object": delta_type,
+                                "choices": [{"delta": {"content": delta["text"]}}],
+                            },
+                            tight=True,
+                        ).encode()
+                        yield b""  # uncomment this line if you want 1:1 with OpenAI
+                    else:
+                        yield delta["text"]
+                elif delta_type == "input_json_delta":
+                    fn_call["arguments"] += delta["partial_json"]
+            elif resp["type"] == "content_block_stop":
+                if fn_call:
+                    fn_call["arguments"] = tu.from_json(fn_call["arguments"] or "{}")
+                    yield fn_call
+                    fn_call = None
+            elif resp["type"] == "message_delta":
+                usage_dict["output_tokens"] += resp["usage"]["output_tokens"]
+                cached_tokens = usage_dict.get(
+                    "cache_read_input_tokens", 0
+                ) or usage_dict.get("cache_creation_input_tokens", 0)
+                usage_obj = tt.Usage(
+                    input_tokens=usage_dict.pop("input_tokens"),
+                    output_tokens=usage_dict.pop("output_tokens"),
+                    cached_tokens=cached_tokens,
+                    **usage_dict,
+                )
+                if yield_usage:
+                    if raw:
+                        yield b"data: " + usage_obj.to_json(tight=True).encode()
+                        yield b""  # uncomment this line if you want 1:1 with OpenAI
+                    else:
+                        yield usage_obj
 
     # Interaction methods
 
@@ -212,30 +239,35 @@ def chat(
         max_tokens: int = 1024,
         temperature: Optional[float] = None,
         token: Optional[str] = None,
-        return_message: bool = False,
+        usage: bool = False,
         extra_headers: Optional[Dict[str, str]] = None,
         **kwargs,
     ):
         output = ""
+        usage_obj = None
         fn_call = None
         for i in self.stream_chat(
             chats=chats,
             model=model,
             max_tokens=max_tokens,
             temperature=temperature,
             token=token,
+            usage=usage,
             extra_headers=extra_headers,
             raw=False,
             **kwargs,
         ):
             if isinstance(i, dict):
                 fn_call = i.copy()
+            elif isinstance(i, tt.Usage):
+                usage_obj = i
             else:
                 output += i
-        if return_message:
-            return output, fn_call
         if fn_call:
-            return fn_call
+            output = fn_call
+
+        if usage:
+            return output, usage_obj
         return output
 
     def stream_chat(
@@ -246,6 +278,7 @@ def stream_chat(
         temperature: Optional[float] = None,
         token: Optional[str] = None,
         debug: bool = False,
+        usage: bool = False,
         extra_headers: Optional[Dict[str, str]] = None,
         timeout=(5, 30),
         raw: bool = False,
@@ -262,19 +295,23 @@ def stream_chat(
             extra_headers=extra_headers,
             **kwargs,
         )
-        r = requests.post(
-            self.base_url,
-            headers=headers,
-            json=data,
-            timeout=timeout,
-        )
         try:
+            r = requests.post(
+                self.base_url,
+                headers=headers,
+                json=data,
+                timeout=timeout,
+            )
             r.raise_for_status()
         except Exception as e:
             yield r.text
             raise e
 
-        yield from self._process_output(raw=raw, lines_fn=r.iter_lines)
+        yield from self._process_output(
+            raw=raw,
+            lines_fn=r.iter_lines,
+            yield_usage=usage,
+        )
 
     async def chat_async(
         self,
@@ -283,30 +320,35 @@ async def chat_async(
         max_tokens: int = 1024,
         temperature: Optional[float] = None,
         token: Optional[str] = None,
-        return_message: bool = False,
+        usage: bool = False,
         extra_headers: Optional[Dict[str, str]] = None,
         **kwargs,
     ):
         output = ""
+        usage_obj = None
         fn_call = None
         async for i in self.stream_chat_async(
             chats=chats,
             model=model,
             max_tokens=max_tokens,
             temperature=temperature,
             token=token,
+            usage=usage,
             extra_headers=extra_headers,
             raw=False,
             **kwargs,
         ):
             if isinstance(i, dict):
                 fn_call = i.copy()
+            elif isinstance(i, tt.Usage):
+                usage_obj = i
             else:
                 output += i
-        if return_message:
-            return output, fn_call
+
         if fn_call:
-            return fn_call
+            output = fn_call
+        if usage:
+            return output, usage_obj
         return output
 
     async def stream_chat_async(
@@ -317,6 +359,7 @@ async def stream_chat_async(
         temperature: Optional[float] = None,
         token: Optional[str] = None,
         debug: bool = False,
+        usage: bool = False,
         extra_headers: Optional[Dict[str, str]] = None,
         timeout=(5, 30),
         raw: bool = False,
@@ -351,6 +394,7 @@ async def stream_chat_async(
                 for x in self._process_output(
                     raw=raw,
                     lines_fn=chunk.decode("utf-8").splitlines,
+                    yield_usage=usage,
                 ):
                     yield x
 
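In short, this diff threads a new ``usage`` flag from ``chat``/``chat_async`` down into ``_process_output``, which now reads token counts from the ``message_start`` and ``message_delta`` stream events and yields one ``tt.Usage`` at the end of the stream; it also accepts list-valued message content and an explicit ``api_token`` constructor argument. A minimal sketch of consuming the stream, again assuming the ``Anthropic`` export and ``tt.human`` helper:

```python
# Sketch: with usage=True, stream_chat yields text deltas (str), tool calls
# (dict), and finally one tt.Usage built from message_start/message_delta.
import tuneapi.types as tt
from tuneapi.apis import Anthropic  # assumed export path

model = Anthropic(api_token="sk-ant-...")  # new in 8.0.2: explicit token
thread = tt.Thread(tt.human("Write a haiku about types."))

for chunk in model.stream_chat(thread, usage=True):
    if isinstance(chunk, tt.Usage):
        # arrives once, after the message_delta event closes the stream
        print("\ntokens:", chunk.input_tokens, "in /", chunk.output_tokens, "out")
    elif isinstance(chunk, dict):
        print("tool call:", chunk)  # parsed function-call payload
    else:
        print(chunk, end="")  # plain text delta
```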