Commit ac3af6b

[0.6.3] adding structured generation to Gemini API
1 parent: 97009ad

12 files changed: +106 additions, -29 deletions


docs/changelog.rst (16 additions, 0 deletions)

@@ -7,6 +7,22 @@ minor versions.
 
 All relevant steps to be taken will be mentioned here.
 
+0.6.3
+-----
+
+- ``<model>.distributed_chat`` now takes in args that are passed to the ``post_logic``.
+
+
+0.6.2
+-----
+
+- New set of utils in ``tuneapi.utils`` called ``prompt`` to help with the basics of prompting.
+
+0.6.1
+-----
+
+- Package now uses ``fire==0.7.0``
+
 0.6.0
 -----
 
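As a usage sketch of the 0.6.3 entry (the ``Gemini`` export name and the prompt contents are illustrative, and the API token is assumed to come from the environment), extra keyword arguments given to ``distributed_chat`` ride along with each task and reach the per-worker chat calls (see tuneapi/apis/turbo.py below):

from tuneapi import apis as ta
from tuneapi import types as tt

model = ta.Gemini()  # illustrative; any ModelInterface implementation works
prompts = [tt.Thread(tt.human(f"Summarise item {i}")) for i in range(5)]

# Extra kwargs such as `temperature` are stored on each task and forwarded
# to every worker's model.chat(...) call.
results = model.distributed_chat(
    prompts,
    post_logic=lambda out: out.strip(),
    max_threads=4,
    temperature=0.2,
)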

pyproject.toml (2 additions, 2 deletions)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tuneapi"
-version = "0.6.0"
+version = "0.6.3"
 description = "Tune AI APIs."
 authors = ["Frello Technology Private Limited <[email protected]>"]
 license = "MIT"
@@ -9,7 +9,7 @@ repository = "https://github.com/NimbleBoxAI/tuneapi"
 
 [tool.poetry.dependencies]
 python = "^3.10"
-fire = "0.5.0"
+fire = "0.7.0"
 requests = "^2.31.0"
 cloudpickle = "3.0.0"
 cryptography = ">=42.0.5"

tuneapi/apis/model_anthropic.py (2 additions, 13 deletions)

@@ -4,7 +4,6 @@
 
 # Copyright © 2024- Frello Technology Private Limited
 
-import re
 import json
 import requests
 from typing import Optional, Dict, Any, Tuple, List
@@ -244,6 +243,7 @@ def distributed_chat(
         max_threads: int = 10,
         retry: int = 3,
         pbar=True,
+        **kwargs,
     ):
         return distributed_chat(
             self,
@@ -252,16 +252,5 @@ def distributed_chat(
             max_threads=max_threads,
             retry=retry,
             pbar=pbar,
+            **kwargs,
         )
-
-
-# helper methods
-
-
-def get_section(tag: str, out: str) -> Optional[str]:
-    pattern = re.compile("<" + tag + ">(.*?)</" + tag + ">", re.DOTALL)
-    match = pattern.search(out)
-    if match:
-        content = match.group(1)
-        return content
-    return None

tuneapi/apis/model_gemini.py (21 additions, 9 deletions)

@@ -114,7 +114,7 @@ def chat(
         self,
         chats: tt.Thread | str,
         model: Optional[str] = None,
-        max_tokens: int = 1024,
+        max_tokens: int = 4096,
         temperature: float = 1,
         token: Optional[str] = None,
         timeout=None,
@@ -150,7 +150,7 @@ def stream_chat(
         self,
         chats: tt.Thread | str,
         model: Optional[str] = None,
-        max_tokens: int = 1024,
+        max_tokens: int = 4096,
         temperature: float = 1,
         token: Optional[str] = None,
         timeout=(5, 60),
@@ -166,18 +166,12 @@ def stream_chat(
         extra_headers = extra_headers or self.extra_headers
         if extra_headers:
             headers.update(extra_headers)
+
         data = {
             "systemInstruction": {
                 "parts": [{"text": system}],
             },
             "contents": messages,
-            "generationConfig": {
-                "temperature": temperature,
-                "topK": 0,
-                "topP": 0.95,
-                "maxOutputTokens": max_tokens,
-                "stopSequences": [],
-            },
             "safetySettings": [
                 {
                     "category": "HARM_CATEGORY_HARASSMENT",
@@ -197,6 +191,22 @@ def stream_chat(
                 },
             ],
         }
+
+        generation_config = {
+            "temperature": temperature,
+            "maxOutputTokens": max_tokens,
+            "stopSequences": [],
+        }
+
+        if chats.gen_schema:
+            generation_config.update(
+                {
+                    "response_mime_type": "application/json",
+                    "response_schema": chats.gen_schema,
+                }
+            )
+        data["generationConfig"] = generation_config
+
         if tools:
             data["tool_config"] = {
                 "function_calling_config": {
@@ -285,6 +295,7 @@ def distributed_chat(
         max_threads: int = 10,
         retry: int = 3,
         pbar=True,
+        **kwargs,
     ):
         return distributed_chat(
             self,
@@ -293,4 +304,5 @@ def distributed_chat(
             max_threads=max_threads,
             retry=retry,
             pbar=pbar,
+            **kwargs,
         )
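A minimal sketch of the new structured-generation path, assuming the class is exported as ``Gemini`` and the token comes from the environment; the schema follows the Gemini REST API's OpenAPI-style subset (uppercase type enums), and ``gen_schema`` is the new Thread field added in this commit:

import json

from tuneapi import apis as ta
from tuneapi import types as tt

# OpenAPI-style schema subset accepted by the Gemini REST API.
schema = {
    "type": "OBJECT",
    "properties": {
        "name": {"type": "STRING"},
        "rating": {"type": "INTEGER"},
    },
    "required": ["name", "rating"],
}

thread = tt.Thread(
    tt.human("Rate the movie Dune (2021). Return name and rating."),
    gen_schema=schema,  # stream_chat now copies this into generationConfig
)

model = ta.Gemini()  # assumed export name
out = model.chat(thread)  # with gen_schema set, the reply is JSON text
print(json.loads(out))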

tuneapi/apis/model_groq.py (2 additions, 0 deletions)

@@ -199,6 +199,7 @@ def distributed_chat(
         max_threads: int = 10,
         retry: int = 3,
         pbar=True,
+        **kwargs,
     ):
         return distributed_chat(
             self,
@@ -207,4 +208,5 @@ def distributed_chat(
             max_threads=max_threads,
             retry=retry,
             pbar=pbar,
+            **kwargs,
         )

tuneapi/apis/model_mistral.py (2 additions, 0 deletions)

@@ -201,6 +201,7 @@ def distributed_chat(
         max_threads: int = 10,
         retry: int = 3,
         pbar=True,
+        **kwargs,
     ):
         return distributed_chat(
             self,
@@ -209,4 +210,5 @@ def distributed_chat(
             max_threads=max_threads,
             retry=retry,
             pbar=pbar,
+            **kwargs,
         )

tuneapi/apis/model_openai.py (5 additions, 0 deletions)

@@ -130,6 +130,7 @@ def stream_chat(
         extra_headers: Optional[Dict[str, str]] = None,
         debug: bool = False,
         raw: bool = False,
+        **kwargs,
     ):
         headers, messages = self._process_input(chats, token)
         extra_headers = extra_headers or self.extra_headers
@@ -148,6 +149,8 @@ def stream_chat(
                 {"type": "function", "function": x.to_dict()} for x in chats.tools
             ]
             data["parallel_tool_calls"] = parallel_tool_calls
+        if kwargs:
+            data.update(kwargs)
         if debug:
             fp = "sample_oai.json"
             print("Saving at path " + fp)
@@ -198,6 +201,7 @@ def distributed_chat(
         max_threads: int = 10,
         retry: int = 3,
         pbar=True,
+        **kwargs,
     ):
         return distributed_chat(
             self,
@@ -206,6 +210,7 @@ def distributed_chat(
             max_threads=max_threads,
             retry=retry,
             pbar=pbar,
+            **kwargs,
         )
 
     def embedding(
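Because ``stream_chat`` now does ``data.update(kwargs)``, documented request-body fields can be passed straight through. A sketch, assuming the class is exported as ``Openai``; ``response_format`` is a standard OpenAI chat-completions field, not something this commit defines:

from tuneapi import apis as ta
from tuneapi import types as tt

model = ta.Openai()  # assumed export name; token via env/config
thread = tt.Thread(tt.human('Reply with the JSON object {"ok": true}'))

# The extra kwarg is merged verbatim into the request payload.
for token in model.stream_chat(thread, response_format={"type": "json_object"}):
    print(token, end="")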

tuneapi/apis/model_tune.py (2 additions, 0 deletions)

@@ -226,6 +226,7 @@ def distributed_chat(
         max_threads: int = 10,
         retry: int = 3,
         pbar=True,
+        **kwargs,
     ):
         return distributed_chat(
             self,
@@ -234,4 +235,5 @@ def distributed_chat(
             max_threads=max_threads,
             retry=retry,
             pbar=pbar,
+            **kwargs,
         )

tuneapi/apis/turbo.py (20 additions, 5 deletions)

@@ -3,7 +3,7 @@
 import queue
 import threading
 from tqdm import trange
-from typing import List, Optional
+from typing import List, Optional, Dict
 from dataclasses import dataclass
 
 from tuneapi.types import Thread, ModelInterface, human, system
@@ -16,6 +16,7 @@ def distributed_chat(
     max_threads: int = 10,
     retry: int = 3,
     pbar=True,
+    **kwargs,
 ):
     """
     Distributes multiple chat prompts across a thread pool for parallel processing.
@@ -78,8 +79,7 @@ def worker():
                 break
 
             try:
-                print(">")
-                out = task.model.chat(task.prompt)
+                out = task.model.chat(chat=task.prompt, **task.kwargs)
                 if post_logic:
                     out = post_logic(out)
                 result_channel.put(_Result(task.index, out, True))
@@ -94,7 +94,13 @@ def worker():
                     nm.set_api_token(model.api_token)
                     # Increment retry count and requeue
                     task_channel.put(
-                        _Task(task.index, nm, task.prompt, task.retry_count + 1)
+                        _Task(
+                            index=task.index,
+                            model=nm,
+                            prompt=task.prompt,
+                            retry_count=task.retry_count + 1,
+                            kwargs=task.kwargs,
+                        )
                     )
                 else:
                     # If we've exhausted retries, store the error
@@ -122,7 +128,15 @@ def worker():
             extra_headers=model.extra_headers,
         )
         nm.set_api_token(model.api_token)
-        task_channel.put(_Task(i, nm, p))
+        task_channel.put(
+            _Task(
+                index=i,
+                model=nm,
+                prompt=p,
+                retry_count=0,
+                kwargs=kwargs,
+            )
+        )
 
     # Process results
     completed = 0
@@ -160,6 +174,7 @@ class _Task:
     model: ModelInterface
     prompt: Thread
     retry_count: int = 0
+    kwargs: Optional[Dict] = None
 
 
 @dataclass
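One consequence of this refactor worth noting: ``post_logic`` runs inside the worker's try block, so raising from it requeues the task with the same kwargs, up to ``retry`` times. A hedged sketch, reusing the ``model`` and ``prompts`` names from the changelog example above:

import json

from tuneapi.apis.turbo import distributed_chat

def parse_or_retry(out: str) -> dict:
    # Raising here sends the task back through the _Task requeue path,
    # since post_logic is called inside the worker's try/except.
    return json.loads(out)

results = distributed_chat(
    model,                  # any ModelInterface implementation
    prompts,                # List[Thread]
    post_logic=parse_or_retry,
    max_threads=8,
    retry=3,
    temperature=0.0,        # carried on _Task.kwargs into every model.chat(...)
)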

tuneapi/types/chats.py (15 additions, 0 deletions)

@@ -346,6 +346,17 @@ def stream_chat(
     ):
         """This is the main function to stream chat with the model where each token is iteratively generated"""
 
+    def distributed_chat(
+        self,
+        prompts: List["Thread"],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+        **kwargs,
+    ):
+        """This is the main function to chat with the model in a distributed manner"""
+
 
 ########################################################################################################################
 #
@@ -372,6 +383,7 @@ def __init__(
         id: str = "",
         title: str = "",
         tools: List[Tool] = [],
+        gen_schema: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
         self.chats = list(chats)
@@ -380,6 +392,7 @@ def __init__(
         self.id = id or "thread_" + str(tu.get_snowflake())
         self.title = title
         self.tools = tools
+        self.gen_schema = gen_schema
 
         #
         kwargs = {k: v for k, v in sorted(kwargs.items())}
@@ -462,6 +475,7 @@ def to_dict(self, full: bool = False):
                 "title": self.title,
                 "id": self.id,
                 "tools": [x.to_dict() for x in self.tools],
+                "gen_schema": self.gen_schema,
             }
         return {
             "chats": [x.to_dict() for x in self.chats],
@@ -484,6 +498,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "Thread":
             model=data.get("model", ""),
             title=data.get("title", ""),
             tools=[Tool.from_dict(x) for x in data.get("tools", [])],
+            gen_schema=data.get("gen_schema", {}),
             **data.get("meta", {}),
         )