@@ -19,7 +19,9 @@ class LanguageModel(SynalinksSaveable):
1919 structures in language. Language models can perform various tasks such as text
2020 generation, translation, summarization, and answering questions.
2121
22- Many providers are available like OpenAI, Anthropic, Groq, or Ollama.
22+ We support providers that implement *constrained structured output*
23+ like OpenAI, Ollama or Mistral. In addition we support providers that otherwise
24+ allow constraining the use of a specific tool, like Groq or Anthropic.
2325
2426 For the complete list of models, please refer to the providers documentation.
2527
@@ -31,9 +33,24 @@ class LanguageModel(SynalinksSaveable):
3133
3234 os.environ["OPENAI_API_KEY"] = "your-api-key"
3335
34- language_model = synalinks.LanguageModel(model="openai/gpt-4o-mini")
36+ language_model = synalinks.LanguageModel(
37+ model="openai/gpt-4o-mini",
38+ )
3539 ```
3640
41+ **Using Groq models**
42+
43+ ```python
44+ import synalinks
45+ import os
46+
47+ os.environ["GROQ_API_KEY"] = "your-api-key"
48+
49+ language_model = synalinks.LanguageModel(
50+ model="groq/llama3-8b-8192",
51+ )
52+ ```
53+
3754 **Using Anthropic models**
3855
3956 ```python
@@ -46,17 +63,17 @@ class LanguageModel(SynalinksSaveable):
4663 model="anthropic/claude-3-sonnet-20240229",
4764 )
4865 ```
49-
50- **Using Groq models**
66+
67+ **Using Mistral models**
5168
5269 ```python
5370 import synalinks
5471 import os
5572
56- os.environ["GROQ_API_KEY "] = "your-api-key"
73+ os.environ["MISTRAL_API_KEY"] = "your-api-key"
5774
5875 language_model = synalinks.LanguageModel(
59- model="groq/llama3-8b-8192 ",
76+ model="mistral/codestral-latest",
6077 )
6178 ```
6279
@@ -111,7 +128,7 @@ async def __call__(self, messages, schema=None, streaming=False, **kwargs):
111128 json_instance = {}
112129 if schema :
113130 if self .model .startswith ("groq" ):
114- # Use a tool created on the fly for Groq
131+ # Use a tool created on the fly for groq
115132 kwargs .update (
116133 {
117134 "tools" : [
@@ -130,15 +147,60 @@ async def __call__(self, messages, schema=None, streaming=False, **kwargs):
130147 },
131148 }
132149 )
133- else :
150+ elif self .model .startswith ("anthropic" ):
151+ # Use a tool created on the fly for anthropic
152+ kwargs .update (
153+ {
154+ "tools" : [
155+ {
156+ "name" : "structured_output" ,
157+ "description" : "Generate a valid JSON output" ,
158+ "input_schema" : {
159+ "type" : "object" ,
160+ "properties" : schema .get ("properties" ),
161+ "required" : schema .get ("required" ),
162+ }
163+ }
164+ ],
165+ "tool_choice" : {
166+ "type" : "tool" ,
167+ "name" : "structured_output" ,
168+ }
169+ }
170+ )
171+ elif self .model .startswith ("ollama" ) or self .model .startswith ("mistral" ):
172+ # Use constrained structured output for ollama/mistral
134173 kwargs .update (
135174 {
136175 "response_format" : {
137176 "type" : "json_schema" ,
138- "json_schema" : {"schema" : schema },
177+ "json_schema" : {
178+ "schema" : schema
179+ },
180+ "strict" : True ,
139181 },
140182 }
141183 )
184+ elif self .model .startswith ("openai" ):
185+ # Use constrained structured output for openai
186+ kwargs .update (
187+ {
188+ "response_format" : {
189+ "type" : "json_schema" ,
190+ "json_schema" : {
191+ "name" : "structured_output" ,
192+ "strict" : True ,
193+ "schema" : schema ,
194+ }
195+ }
196+ }
197+ )
198+ else :
199+ provider = self .model .split ("/" )[0 ]
200+ raise ValueError (
201+ f"LM provider '{ provider } ' not supported yet, please ensure that"
202+ " they support constrained structured output and file an issue."
203+ )
142204
143205 if self .api_base :
144206 kwargs .update (
@@ -165,6 +227,11 @@ async def __call__(self, messages, schema=None, streaming=False, **kwargs):
165227 response_str = response ["choices" ][0 ]["message" ]["tool_calls" ][0 ][
166228 "function"
167229 ]["arguments" ]
230+ elif self .model .startswith ("anthropic" ) and schema :
231+ for content_block in response ["content" ]:
232+ if content_block ["type" ] == "tool_use" :
233+ response_str = json .dumps (content_block ["input" ])
234+ break
168235 else :
169236 response_str = response ["choices" ][0 ]["message" ]["content" ].strip ()
170237 if schema :
@@ -174,7 +241,6 @@ async def __call__(self, messages, schema=None, streaming=False, **kwargs):
174241 return json_instance
175242 except Exception as e :
176243 warnings .warn (str (e ))
177- raise e
178244 return None
179245
180246 def _obj_type (self ):
0 commit comments