Skip to content

Commit aa6ab55

Browse files
committed
[0.7.1] add structured generation
1 parent 73823b9 commit aa6ab55

File tree

11 files changed

+374
-50
lines changed

11 files changed

+374
-50
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"from tuneapi import tu, tt, ta\n",
10+
"from dataclasses import dataclass\n",
11+
"from pydantic import BaseModel\n",
12+
"from typing import List, Optional, Dict, Any"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": 2,
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"class MedicalRecord(BaseModel):\n",
22+
" date: str\n",
23+
" diagnosis: str\n",
24+
" treatment: str\n",
25+
"\n",
26+
"class Dog(BaseModel):\n",
27+
" name: str\n",
28+
" breed: str\n",
29+
" records: Optional[List[MedicalRecord]] = None\n",
30+
"\n",
31+
"class Dogs(BaseModel):\n",
32+
" dogs: List[Dog]\n"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": 6,
38+
"metadata": {},
39+
"outputs": [
40+
{
41+
"name": "stdout",
42+
"output_type": "stream",
43+
"text": [
44+
"Dog: Buddy, Breed: Golden Retriever\n",
45+
" Date: 2023-10-26, Diagnosis: Mild ear infection, Treatment: Ear drops\n",
46+
"\n",
47+
"Dog: Luna, Breed: Beagle\n",
48+
" Date: 2023-10-25, Diagnosis: Routine check-up, Treatment: No treatment needed\n",
49+
" Date: 2023-10-28, Diagnosis: Upset tummy, Treatment: Bland diet and probiotics\n",
50+
"\n",
51+
"Dog: Rocky, Breed: Terrier Mix\n",
52+
" Date: 2023-10-29, Diagnosis: Cut on paw, Treatment: Cleaned and antibiotic ointment\n",
53+
"\n",
54+
"Dog: Daisy, Breed: Poodle\n",
55+
" No medical records on file.\n",
56+
"\n"
57+
]
58+
}
59+
],
60+
"source": [
61+
"# As of this moment we have tested it with the following LLMs:\n",
62+
"\n",
63+
"# model = ta.Openai()\n",
64+
"model = ta.Gemini()\n",
65+
"\n",
66+
"out: Dogs = model.chat(tt.Thread(\n",
67+
" tt.human(\"\"\"\n",
68+
" At the Sunny Paws Animal Clinic, we keep detailed records of all our furry patients. Today, we saw a few dogs.\n",
69+
" There was 'Buddy,' a golden retriever, who visited on '2023-10-26' and was diagnosed with a 'mild ear infection,'\n",
70+
" which we treated with 'ear drops.' Then, there was 'Luna,' a playful beagle, who came in on '2023-10-25' for a\n",
71+
" 'routine check-up,' and no treatment was needed, but we also had her back on '2023-10-28' with a 'upset tummy'\n",
72+
" which we treated with 'bland diet and probiotics.' Finally, a third dog named 'Rocky', a small terrier mix,\n",
73+
" showed up on '2023-10-29' with a small 'cut on his paw,' we cleaned it and used an 'antibiotic ointment'. We\n",
74+
" also have 'Daisy,' a fluffy poodle, who doesn't have any medical records yet, thankfully!\n",
75+
" \"\"\"),\n",
76+
" schema=Dogs,\n",
77+
"))\n",
78+
"\n",
79+
"for dog in out.dogs:\n",
80+
" print(f\"Dog: {dog.name}, Breed: {dog.breed}\")\n",
81+
" if dog.records:\n",
82+
" for record in dog.records:\n",
83+
" print(f\" Date: {record.date}, Diagnosis: {record.diagnosis}, Treatment: {record.treatment}\")\n",
84+
" else:\n",
85+
" print(\" No medical records on file.\")\n",
86+
" print()"
87+
]
88+
}
89+
],
90+
"metadata": {
91+
"kernelspec": {
92+
"display_name": "Python 3",
93+
"language": "python",
94+
"name": "python3"
95+
},
96+
"language_info": {
97+
"codemirror_mode": {
98+
"name": "ipython",
99+
"version": 3
100+
},
101+
"file_extension": ".py",
102+
"mimetype": "text/x-python",
103+
"name": "python",
104+
"nbconvert_exporter": "python",
105+
"pygments_lexer": "ipython3",
106+
"version": "3.12.7"
107+
}
108+
},
109+
"nbformat": 4,
110+
"nbformat_minor": 2
111+
}

docs/changelog.rst

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,56 @@ minor versions.
77

88
All relevant steps to be taken will be mentioned here.
99

10+
0.7.1
11+
-----
12+
13+
- Add structured generation support for Gemini and OpenAI APIs. You can just pass ``schema`` to ``Thread``. ``model.chat``
14+
will take care of it automatically. Here's an example:
15+
16+
.. code-block:: python
17+
18+
from tuneapi import tt, ta
19+
from pydantic import BaseModel
20+
from typing import List, Optional, Dict, Any
21+
22+
class MedicalRecord(BaseModel):
23+
date: str
24+
diagnosis: str
25+
treatment: str
26+
27+
class Dog(BaseModel):
28+
name: str
29+
breed: str
30+
records: Optional[List[MedicalRecord]] = None
31+
32+
class Dogs(BaseModel):
33+
dogs: List[Dog]
34+
35+
model = ta.Gemini()
36+
out: Dogs = model.chat(tt.Thread(
37+
tt.human("""
38+
At the Sunny Paws Animal Clinic, we keep detailed records of all our furry patients. Today, we saw a few dogs.
39+
There was 'Buddy,' a golden retriever, who visited on '2023-10-26' and was diagnosed with a 'mild ear infection,'
40+
which we treated with 'ear drops.' Then, there was 'Luna,' a playful beagle, who came in on '2023-10-25' for a
41+
'routine check-up,' and no treatment was needed, but we also had her back on '2023-10-28' with a 'upset tummy'
42+
which we treated with 'bland diet and probiotics.' Finally, a third dog named 'Rocky', a small terrier mix,
43+
showed up on '2023-10-29' with a small 'cut on his paw,' we cleaned it and used an 'antibiotic ointment'. We
44+
also have 'Daisy,' a fluffy poodle, who doesn't have any medical records yet, thankfully!
45+
"""),
46+
schema=Dogs,
47+
))
48+
49+
for dog in out.dogs:
50+
print(f"Dog: {dog.name}, Breed: {dog.breed}")
51+
if dog.records:
52+
for record in dog.records:
53+
print(f" Date: {record.date}, Diagnosis: {record.diagnosis}, Treatment: {record.treatment}")
54+
else:
55+
print(" No medical records on file.")
56+
print()
57+
58+
- Add ``pydantic`` as a dependency in the package.
59+
1060
0.7.0
1161
-----
1262

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
project = "tuneapi"
1414
copyright = "2024, Frello Technologies"
1515
author = "Frello Technologies"
16-
release = "0.5.13"
16+
release = "0.7.1"
1717

1818
# -- General configuration ---------------------------------------------------
1919
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "tuneapi"
3-
version = "0.7.0"
3+
version = "0.7.1"
44
description = "Tune AI APIs."
55
authors = ["Frello Technology Private Limited <[email protected]>"]
66
license = "MIT"
@@ -18,6 +18,7 @@ snowflake_id = "1.0.2"
1818
nutree = "0.8.0"
1919
pillow = "^10.2.0"
2020
httpx = "^0.28.1"
21+
pydantic = "^2.6.4"
2122
protobuf = { version = "^5.27.3", optional = true }
2223
boto3 = { version = "1.29.6", optional = true }
2324

tuneapi/apis/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
from tuneapi.apis.model_groq import Groq
88
from tuneapi.apis.model_mistral import Mistral
99
from tuneapi.apis.model_gemini import Gemini
10-
from tuneapi.apis.turbo import distributed_chat
10+
from tuneapi.apis.turbo import distributed_chat, distributed_chat_async

tuneapi/apis/model_gemini.py

Lines changed: 113 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
import httpx
99
import requests
10-
from typing import Optional, Any, Dict, List
10+
from pydantic import BaseModel
11+
from typing import get_args, get_origin, List, Optional, Dict, Any, Union
1112

1213
import tuneapi.utils as tu
1314
import tuneapi.types as tt
@@ -110,6 +111,106 @@ def _process_header(self):
110111
"Content-Type": "application/json",
111112
}
112113

114+
@staticmethod
115+
def get_structured_schema(model: type[BaseModel]) -> Dict[str, Any]:
116+
"""
117+
Converts a Pydantic BaseModel to a JSON schema compatible with Gemini API,
118+
including `anyOf` for optional or union types and handling nested structures correctly.
119+
120+
Args:
121+
model: The Pydantic BaseModel class to convert.
122+
123+
Returns:
124+
A dictionary representing the JSON schema.
125+
"""
126+
127+
def _process_field(
128+
field_name: str, field_type: Any, field_description: str = None
129+
) -> dict:
130+
"""Helper function to process a single field."""
131+
schema = {}
132+
origin = get_origin(field_type)
133+
args = get_args(field_type)
134+
135+
if origin is list:
136+
schema["type"] = "array"
137+
if args:
138+
item_schema = _process_field_type(args[0])
139+
schema["items"] = item_schema
140+
if "type" not in item_schema and "anyOf" not in item_schema:
141+
schema["items"]["type"] = "object" # default item type for list
142+
else:
143+
schema["items"] = {}
144+
elif origin is Optional:
145+
if args:
146+
inner_schema = _process_field_type(args[0])
147+
schema["anyOf"] = [inner_schema, {"type": "null"}]
148+
else:
149+
schema = {"type": "null"}
150+
elif origin is dict:
151+
schema["type"] = "object"
152+
if len(args) == 2:
153+
schema["additionalProperties"] = _process_field_type(args[1])
154+
else:
155+
schema = _process_field_type(field_type)
156+
157+
if field_description:
158+
schema["description"] = field_description
159+
return schema
160+
161+
def _process_field_type(field_type: Any) -> dict:
162+
"""Helper function to process the type of a field."""
163+
164+
origin = get_origin(field_type)
165+
args = get_args(field_type)
166+
167+
if field_type is str:
168+
return {"type": "string"}
169+
elif field_type is int:
170+
return {"type": "integer"}
171+
elif field_type is float:
172+
return {"type": "number"}
173+
elif field_type is bool:
174+
return {"type": "boolean"}
175+
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
176+
return Gemini.get_structured_schema(
177+
field_type
178+
) # Recursive call for nested models
179+
elif origin is list:
180+
schema = {"type": "array"}
181+
if args:
182+
item_schema = _process_field_type(args[0])
183+
schema["items"] = item_schema
184+
if "type" not in item_schema and "anyOf" not in item_schema:
185+
schema["items"]["type"] = "object"
186+
return schema
187+
elif origin is Optional:
188+
return _process_field_type(args[0])
189+
elif origin is dict:
190+
schema = {"type": "object"}
191+
if len(args) == 2:
192+
schema["additionalProperties"] = _process_field_type(args[1])
193+
return schema
194+
elif origin is Union:
195+
return _process_field_type(args[0])
196+
else:
197+
return {"type": "string"} # default any object to string
198+
199+
schema = {"type": "object", "properties": {}, "required": []}
200+
201+
for field_name, field in model.model_fields.items():
202+
field_description = field.description
203+
if field.is_required():
204+
schema["required"].append(field_name)
205+
206+
schema["properties"][field_name] = _process_field(
207+
field_name, field.annotation, field_description
208+
)
209+
210+
if model.__doc__:
211+
schema["description"] = model.__doc__.strip()
212+
return schema
213+
113214
def chat(
114215
self,
115216
chats: tt.Thread | str,
@@ -139,11 +240,13 @@ def chat(
139240
output = x
140241
else:
141242
output += x
142-
except Exception as e:
143-
if not x:
144-
raise e
145-
else:
146-
raise ValueError(x)
243+
except requests.HTTPError as e:
244+
print(e.response.text)
245+
raise e
246+
247+
if chats.schema:
248+
output = chats.schema(**tu.from_json(output))
249+
return output
147250
return output
148251

149252
def stream_chat(
@@ -198,11 +301,11 @@ def stream_chat(
198301
"stopSequences": [],
199302
}
200303

201-
if chats.gen_schema:
304+
if chats.schema:
202305
generation_config.update(
203306
{
204307
"response_mime_type": "application/json",
205-
"response_schema": chats.gen_schema,
308+
"response_schema": self.get_structured_schema(chats.schema),
206309
}
207310
)
208311
data["generationConfig"] = generation_config
@@ -376,11 +479,11 @@ async def stream_chat_async(
376479
"stopSequences": [],
377480
}
378481

379-
if chats.gen_schema:
482+
if chats.schema:
380483
generation_config.update(
381484
{
382485
"response_mime_type": "application/json",
383-
"response_schema": chats.gen_schema,
486+
"response_schema": chats.schema,
384487
}
385488
)
386489
data["generationConfig"] = generation_config

tuneapi/apis/model_mistral.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import tuneapi.utils as tu
1212
import tuneapi.types as tt
1313
from tuneapi.apis.turbo import distributed_chat
14+
from tuneapi.apis.model_openai import Openai as _Openai
1415

1516

1617
class Mistral(tt.ModelInterface):

0 commit comments

Comments
 (0)