Skip to content

Commit 536d948

Browse files
committed
Fix constrained structured output for openai and anthropic
1 parent 0eafe40 commit 536d948

File tree

2 files changed

+77
-11
lines changed

2 files changed

+77
-11
lines changed

synalinks/src/language_models/language_model.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ class LanguageModel(SynalinksSaveable):
1919
structures in language. Language models can perform various tasks such as text
2020
generation, translation, summarization, and answering questions.
2121
22-
Many providers are available like OpenAI, Anthropic, Groq, or Ollama.
22+
We support providers that implement *constrained structured output*
23+
like OpenAI, Ollama or Mistral. In addition, we support providers that otherwise
24+
allow constraining the use of a specific tool, like Groq or Anthropic.
2325
2426
For the complete list of models, please refer to the providers documentation.
2527
@@ -31,9 +33,24 @@ class LanguageModel(SynalinksSaveable):
3133
3234
os.environ["OPENAI_API_KEY"] = "your-api-key"
3335
34-
language_model = synalinks.LanguageModel(model="openai/gpt-4o-mini")
36+
language_model = synalinks.LanguageModel(
37+
model="openai/gpt-4o-mini",
38+
)
3539
```
3640
41+
**Using Groq models**
42+
43+
```python
44+
import synalinks
45+
import os
46+
47+
os.environ["GROQ_API_KEY"] = "your-api-key"
48+
49+
language_model = synalinks.LanguageModel(
50+
model="groq/llama3-8b-8192",
51+
)
52+
```
53+
3754
**Using Anthropic models**
3855
3956
```python
@@ -46,17 +63,17 @@ class LanguageModel(SynalinksSaveable):
4663
model="anthropic/claude-3-sonnet-20240229",
4764
)
4865
```
49-
50-
**Using Groq models**
66+
67+
**Using Mistral models**
5168
5269
```python
5370
import synalinks
5471
import os
5572
56-
os.environ["GROQ_API_KEY"] = "your-api-key"
73+
os.environ["MISTRAL_API_KEY"] = "your-api-key"
5774
5875
language_model = synalinks.LanguageModel(
59-
model="groq/llama3-8b-8192",
76+
model="mistral/codestral-latest",
6077
)
6178
```
6279
@@ -111,7 +128,7 @@ async def __call__(self, messages, schema=None, streaming=False, **kwargs):
111128
json_instance = {}
112129
if schema:
113130
if self.model.startswith("groq"):
114-
# Use a tool created on the fly for Groq
131+
# Use a tool created on the fly for groq
115132
kwargs.update(
116133
{
117134
"tools": [
@@ -130,15 +147,60 @@ async def __call__(self, messages, schema=None, streaming=False, **kwargs):
130147
},
131148
}
132149
)
133-
else:
150+
elif self.model.startswith("anthropic"):
151+
# Use a tool created on the fly for anthropic
152+
kwargs.update(
153+
{
154+
"tools": [
155+
{
156+
"name": "structured_output",
157+
"description": "Generate a valid JSON output",
158+
"input_schema": {
159+
"type": "object",
160+
"properties": schema.get("properties"),
161+
"required": schema.get("required"),
162+
}
163+
}
164+
],
165+
"tool_choice": {
166+
"type": "tool",
167+
"name": "structured_output",
168+
}
169+
}
170+
)
171+
elif self.model.startswith("ollama") or self.model.startswith("mistral"):
172+
# Use constrained structured output for ollama/mistral
134173
kwargs.update(
135174
{
136175
"response_format": {
137176
"type": "json_schema",
138-
"json_schema": {"schema": schema},
177+
"json_schema": {
178+
"schema": schema
179+
},
180+
"strict": True,
139181
},
140182
}
141183
)
184+
elif self.model.startswith("openai"):
185+
# Use constrained structured output for openai
186+
kwargs.update(
187+
{
188+
"response_format": {
189+
"type": "json_schema",
190+
"json_schema": {
191+
"name": "structured_output",
192+
"strict": True,
193+
"schema": schema,
194+
}
195+
}
196+
}
197+
)
198+
else:
199+
provider = self.model.split("/")[0]
200+
raise ValueError(
201+
f"LM provider '{provider}' not supported yet, please ensure that"
202+
" they support constrained structured output and file an issue."
203+
)
142204

143205
if self.api_base:
144206
kwargs.update(
@@ -165,6 +227,11 @@ async def __call__(self, messages, schema=None, streaming=False, **kwargs):
165227
response_str = response["choices"][0]["message"]["tool_calls"][0][
166228
"function"
167229
]["arguments"]
230+
elif self.model.startswith("anthropic") and schema:
231+
for content_block in response["content"]:
232+
if content_block["type"] == "tool_use":
233+
response_str = json.dumps(content_block["input"])
234+
break
168235
else:
169236
response_str = response["choices"][0]["message"]["content"].strip()
170237
if schema:
@@ -174,7 +241,6 @@ async def __call__(self, messages, schema=None, streaming=False, **kwargs):
174241
return json_instance
175242
except Exception as e:
176243
warnings.warn(str(e))
177-
raise e
178244
return None
179245

180246
def _obj_type(self):

synalinks/src/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from synalinks.src.api_export import synalinks_export
44

55
# Unique source of truth for the version number.
6-
__version__ = "0.1.3002"
6+
__version__ = "0.1.3003"
77

88

99
@synalinks_export("synalinks.version")

0 commit comments

Comments
 (0)