Skip to content

Commit 60ffc87

Browse files
committed
[0.5.4] standardise model interface
1 parent 772e272 commit 60ffc87

File tree

15 files changed

+105
-123
lines changed

15 files changed

+105
-123
lines changed

cookbooks/baby_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ async def chat_completions(request: Request, data: ChatCompletionRequest):
7777
max_tokens=data.max_tokens,
7878
)
7979
api_resp = tu.generator_to_api_events(
80-
model=model.tune_model_id,
80+
model=model.model_id,
8181
generator=stream_resp,
8282
)
8383
return StreamingResponse(api_resp, media_type="text/event-stream")

docs/changelog.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ minor versions.
77

88
All relevant steps to be taken will be mentioned here.
99

10+
0.5.4
11+
-----
12+
13+
- Standardise ``tuneapi.types.chats.ModelInterface``: ``model_id`` and ``api_token`` are now defined on the base class.
14+
1015
0.5.3
1116
-----
1217

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
project = "tuneapi"
1414
copyright = "2024, Frello Technologies"
1515
author = "Frello Technologies"
16-
release = "0.5.3"
16+
release = "0.5.4"
1717

1818
# -- General configuration ---------------------------------------------------
1919
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "tuneapi"
3-
version = "0.5.3"
3+
version = "0.5.4"
44
description = "Tune AI APIs."
55
authors = ["Frello Technology Private Limited <[email protected]>"]
66
license = "MIT"

tuneapi/__main__.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

tuneapi/apis/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,3 @@
77
from tuneapi.apis.model_groq import Groq
88
from tuneapi.apis.model_mistral import Mistral
99
from tuneapi.apis.model_gemini import Gemini
10-
11-
12-
# other imports
13-
import os
14-
import random
15-
from time import time
16-
from typing import List, Optional

tuneapi/apis/model_anthropic.py

Lines changed: 18 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,75 +10,44 @@
1010
from typing import Optional, Dict, Any, Tuple, List
1111

1212
import tuneapi.utils as tu
13-
from tuneapi.types import Thread, human, Message
13+
import tuneapi.types as tt
1414

1515

16-
class Anthropic:
16+
class Anthropic(tt.ModelInterface):
1717
def __init__(
1818
self,
1919
id: Optional[str] = "claude-3-haiku-20240307",
2020
base_url: str = "https://api.anthropic.com/v1/messages",
2121
):
22-
self.anthropic_model = id
22+
self.model_id = id
2323
self.base_url = base_url
24-
self.anthropic_api_token = tu.ENV.ANTHROPIC_TOKEN("")
24+
self.api_token = tu.ENV.ANTHROPIC_TOKEN("")
2525

2626
def set_api_token(self, token: str) -> None:
27-
self.anthropic_api_token = token
28-
29-
def tool_to_claude_xml(self, tool):
30-
"""
31-
Deprecated: was written when function calling did not exist in Anthropic API.
32-
"""
33-
tool_signature = ""
34-
if len(tool["parameters"]) > 0:
35-
for name, p in tool["parameters"]["properties"].items():
36-
param = f"""<parameter>
37-
<name> {name} </name>
38-
<type> {p['type']} </type>
39-
<description> {p['description']} </description>
40-
"""
41-
if name in tool["parameters"]["required"]:
42-
param += "<required> true </required>\n"
43-
param += "</parameter>"
44-
tool_signature += param + "\n"
45-
tool_signature = tool_signature.strip()
46-
47-
constructed_prompt = (
48-
"<tool_description>\n"
49-
f"<tool_name> {tool['name']} </tool_name>\n"
50-
"<description>\n"
51-
f"{tool['description']}\n"
52-
"</description>\n"
53-
"<parameters>\n"
54-
f"{tool_signature}\n"
55-
"</parameters>\n"
56-
"</tool_description>"
57-
)
58-
return constructed_prompt
27+
self.api_token = token
5928

6029
def _process_input(self, chats, token: Optional[str] = None):
61-
if not token and not self.anthropic_api_token: # type: ignore
30+
if not token and not self.api_token: # type: ignore
6231
raise Exception(
6332
"Please set ANTHROPIC_TOKEN environment variable or pass through function"
6433
)
65-
token = token or self.anthropic_api_token
66-
if isinstance(chats, Thread):
34+
token = token or self.api_token
35+
if isinstance(chats, tt.Thread):
6736
thread = chats
6837
elif isinstance(chats, str):
69-
thread = Thread(human(chats))
38+
thread = tt.Thread(tt.human(chats))
7039
else:
7140
raise Exception("Invalid input")
7241

7342
# create the anthropic style data
7443
system = ""
75-
if thread.chats[0].role == Message.SYSTEM:
44+
if thread.chats[0].role == tt.Message.SYSTEM:
7645
system = thread.chats[0].value
7746

7847
claude_messages = []
7948
prev_tool_id = tu.get_random_string(5)
8049
for m in thread.chats[int(system != "") :]:
81-
if m.role == Message.HUMAN:
50+
if m.role == tt.Message.HUMAN:
8251
msg = {
8352
"role": "user",
8453
"content": [{"type": "text", "text": m.value.strip()}],
@@ -95,12 +64,12 @@ def _process_input(self, chats, token: Optional[str] = None):
9564
},
9665
}
9766
)
98-
elif m.role == Message.GPT:
67+
elif m.role == tt.Message.GPT:
9968
msg = {
10069
"role": "assistant",
10170
"content": [{"type": "text", "text": m.value.strip()}],
10271
}
103-
elif m.role == Message.FUNCTION_CALL:
72+
elif m.role == tt.Message.FUNCTION_CALL:
10473
_m = tu.from_json(m.value) if isinstance(m.value, str) else m.value
10574
msg = {
10675
"role": "assistant",
@@ -113,7 +82,7 @@ def _process_input(self, chats, token: Optional[str] = None):
11382
}
11483
],
11584
}
116-
elif m.role == Message.FUNCTION_RESP:
85+
elif m.role == tt.Message.FUNCTION_RESP:
11786
# _m = tu.from_json(m.value) if isinstance(m.value, str) else m.value
11887
msg = {
11988
"role": "user",
@@ -139,7 +108,7 @@ def _process_input(self, chats, token: Optional[str] = None):
139108

140109
def chat(
141110
self,
142-
chats: Thread | str,
111+
chats: tt.Thread | str,
143112
model: Optional[str] = None,
144113
max_tokens: int = 1024,
145114
temperature: Optional[float] = None,
@@ -170,7 +139,7 @@ def chat(
170139

171140
def stream_chat(
172141
self,
173-
chats: Thread | str,
142+
chats: tt.Thread | str,
174143
model: Optional[str] = None,
175144
max_tokens: int = 1024,
176145
temperature: Optional[float] = None,
@@ -182,14 +151,14 @@ def stream_chat(
182151
) -> Any:
183152

184153
tools = []
185-
if isinstance(chats, Thread):
154+
if isinstance(chats, tt.Thread):
186155
tools = [x.to_dict() for x in chats.tools]
187156
for t in tools:
188157
t["input_schema"] = t.pop("parameters")
189158
headers, system, claude_messages = self._process_input(chats=chats, token=token)
190159

191160
data = {
192-
"model": model or self.anthropic_model,
161+
"model": model or self.model_id,
193162
"max_tokens": max_tokens,
194163
"messages": claude_messages,
195164
"system": system,

tuneapi/apis/model_gemini.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,21 @@
1313
import tuneapi.types as tt
1414

1515

16-
class Gemini:
16+
class Gemini(tt.ModelInterface):
1717
def __init__(
1818
self,
1919
id: Optional[str] = "gemini-1.5-pro-latest",
2020
base_url: str = "https://generativelanguage.googleapis.com/v1beta/models/{id}:{rpc}",
2121
):
22-
self._gemeni_model_id = id
22+
self.model_id = id
2323
self.base_url = base_url
24-
self.gemini_token = tu.ENV.GEMINI_TOKEN("")
24+
self.api_token = tu.ENV.GEMINI_TOKEN("")
2525

2626
def set_api_token(self, token: str) -> None:
27-
self.gemini_token = token
27+
self.api_token = token
2828

2929
def _process_input(self, chats, token: Optional[str] = None):
30-
if not token and not self.gemini_token: # type: ignore
30+
if not token and not self.api_token: # type: ignore
3131
raise Exception(
3232
"Gemini API key not found. Please set GEMINI_TOKEN environment variable or pass through function"
3333
)
@@ -98,7 +98,7 @@ def _process_input(self, chats, token: Optional[str] = None):
9898

9999
# create headers
100100
headers = self._process_header()
101-
params = {"key": self.gemini_token}
101+
params = {"key": self.api_token}
102102
return headers, system.strip(), messages, params
103103

104104
def _process_header(self):
@@ -197,7 +197,7 @@ def stream_chat(
197197

198198
response = requests.post(
199199
self.base_url.format(
200-
id=model or self._gemeni_model_id,
200+
id=model or self.model_id,
201201
rpc="streamGenerateContent",
202202
),
203203
headers=headers,
@@ -225,12 +225,14 @@ def stream_chat(
225225
continue
226226
block_lines += line
227227

228-
# is the block done?
229-
if line == "{":
230-
done = False
231-
elif line == "}":
228+
done = False
229+
try:
230+
tu.from_json(block_lines)
232231
done = True
232+
except Exception as e:
233+
pass
233234

235+
# print(f"{block_lines=}")
234236
if done:
235237
part_data = json.loads(block_lines)["candidates"][0]["content"][
236238
"parts"

tuneapi/apis/model_groq.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,25 @@
1212
import tuneapi.types as tt
1313

1414

15-
class Groq:
15+
class Groq(tt.ModelInterface):
1616
def __init__(
1717
self,
1818
id: Optional[str] = "llama3-70b-8192",
1919
base_url: str = "https://api.groq.com/openai/v1/chat/completions",
2020
):
21-
self.groq_model_id = id
21+
self.model_id = id
2222
self.base_url = base_url
23-
self.groq_api_token = tu.ENV.GROQ_TOKEN("")
23+
self.api_token = tu.ENV.GROQ_TOKEN("")
2424

2525
def set_api_token(self, token: str) -> None:
26-
self.groq_api_token = token
26+
self.api_token = token
2727

2828
def _process_input(self, chats, token: Optional[str] = None):
29-
if not token and not self.groq_api_token: # type: ignore
29+
if not token and not self.api_token: # type: ignore
3030
raise Exception(
3131
"Please set GROQ_TOKEN environment variable or pass through function"
3232
)
33-
token = token or self.groq_api_token
33+
token = token or self.api_token
3434
if isinstance(chats, tt.Thread):
3535
thread = chats
3636
elif isinstance(chats, str):
@@ -132,7 +132,7 @@ def stream_chat(
132132
data = {
133133
"temperature": temperature,
134134
"messages": messages,
135-
"model": model or self.groq_model_id,
135+
"model": model or self.model_id,
136136
"stream": True,
137137
"max_tokens": max_tokens,
138138
"tools": tools,

tuneapi/apis/model_mistral.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,25 @@
1414
from tuneapi.types import Thread, human, Message
1515

1616

17-
class Mistral:
17+
class Mistral(tt.ModelInterface):
1818
def __init__(
1919
self,
2020
id: Optional[str] = "mistral-small-latest",
2121
base_url: str = "https://api.mistral.ai/v1/chat/completions",
2222
):
23-
self.mistral_model_id = id
23+
self.model_id = id
2424
self.base_url = base_url
25-
self.mistral_api_token = ENV.MISTRAL_TOKEN("")
25+
self.api_token = ENV.MISTRAL_TOKEN("")
2626

2727
def set_api_token(self, token: str) -> None:
28-
self.mistral_api_token = token
28+
self.api_token = token
2929

3030
def _process_input(self, chats, token: Optional[str] = None):
31-
if not token and not self.mistral_api_token: # type: ignore
31+
if not token and not self.api_token: # type: ignore
3232
raise Exception(
3333
"Please set MISTRAL_TOKEN environment variable or pass through function"
3434
)
35-
token = token or self.mistral_api_token
35+
token = token or self.api_token
3636

3737
if isinstance(chats, tt.Thread):
3838
thread = chats
@@ -134,7 +134,7 @@ def stream_chat(
134134
headers, messages = self._process_input(chats, token)
135135
data = {
136136
"messages": messages,
137-
"model": model or self.mistral_model_id,
137+
"model": model or self.model_id,
138138
"stream": True,
139139
"max_tokens": max_tokens,
140140
"tools": tools,

0 commit comments

Comments (0)