1
1
"""
2
- Support for Snowflake REST API
2
+ Support for Snowflake REST API
3
3
"""
4
4
5
- from typing import TYPE_CHECKING , Any , List , Optional , Tuple
5
+ import json
6
+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Union
6
7
7
8
import httpx
8
9
9
10
from litellm .secret_managers .main import get_secret_str
10
11
from litellm .types .llms .openai import AllMessageValues
11
- from litellm .types .utils import ModelResponse
12
+ from litellm .types .utils import ChatCompletionMessageToolCall , Function , ModelResponse
12
13
13
14
from ...openai_like .chat .transformation import OpenAIGPTConfig
14
15
22
23
23
24
class SnowflakeConfig (OpenAIGPTConfig ):
24
25
"""
25
- source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
26
+ Reference: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api
27
+
28
+ Snowflake Cortex LLM REST API supports function calling with specific models (e.g., Claude 3.5 Sonnet).
29
+ This config handles transformation between OpenAI format and Snowflake's tool_spec format.
26
30
"""
27
31
28
32
@classmethod
29
33
def get_config (cls ):
30
34
return super ().get_config ()
31
35
32
- def get_supported_openai_params (self , model : str ) -> List :
33
- return ["temperature" , "max_tokens" , "top_p" , "response_format" ]
36
+ def get_supported_openai_params (self , model : str ) -> List [str ]:
37
+ return [
38
+ "temperature" ,
39
+ "max_tokens" ,
40
+ "top_p" ,
41
+ "response_format" ,
42
+ "tools" ,
43
+ "tool_choice" ,
44
+ ]
34
45
35
46
def map_openai_params (
36
47
self ,
@@ -56,6 +67,57 @@ def map_openai_params(
56
67
optional_params [param ] = value
57
68
return optional_params
58
69
70
+ def _transform_tool_calls_from_snowflake_to_openai (
71
+ self , content_list : List [Dict [str , Any ]]
72
+ ) -> Tuple [str , Optional [List [ChatCompletionMessageToolCall ]]]:
73
+ """
74
+ Transform Snowflake tool calls to OpenAI format.
75
+
76
+ Args:
77
+ content_list: Snowflake's content_list array containing text and tool_use items
78
+
79
+ Returns:
80
+ Tuple of (text_content, tool_calls)
81
+
82
+ Snowflake format in content_list:
83
+ {
84
+ "type": "tool_use",
85
+ "tool_use": {
86
+ "tool_use_id": "tooluse_...",
87
+ "name": "get_weather",
88
+ "input": {"location": "Paris"}
89
+ }
90
+ }
91
+
92
+ OpenAI format (returned tool_calls):
93
+ ChatCompletionMessageToolCall(
94
+ id="tooluse_...",
95
+ type="function",
96
+ function=Function(name="get_weather", arguments='{"location": "Paris"}')
97
+ )
98
+ """
99
+ text_content = ""
100
+ tool_calls : List [ChatCompletionMessageToolCall ] = []
101
+
102
+ for idx , content_item in enumerate (content_list ):
103
+ if content_item .get ("type" ) == "text" :
104
+ text_content += content_item .get ("text" , "" )
105
+
106
+ ## TOOL CALLING
107
+ elif content_item .get ("type" ) == "tool_use" :
108
+ tool_use_data = content_item .get ("tool_use" , {})
109
+ tool_call = ChatCompletionMessageToolCall (
110
+ id = tool_use_data .get ("tool_use_id" , "" ),
111
+ type = "function" ,
112
+ function = Function (
113
+ name = tool_use_data .get ("name" , "" ),
114
+ arguments = json .dumps (tool_use_data .get ("input" , {})),
115
+ ),
116
+ )
117
+ tool_calls .append (tool_call )
118
+
119
+ return text_content , tool_calls if tool_calls else None
120
+
59
121
def transform_response (
60
122
self ,
61
123
model : str ,
@@ -71,13 +133,34 @@ def transform_response(
71
133
json_mode : Optional [bool ] = None ,
72
134
) -> ModelResponse :
73
135
response_json = raw_response .json ()
136
+
74
137
logging_obj .post_call (
75
138
input = messages ,
76
139
api_key = "" ,
77
140
original_response = response_json ,
78
141
additional_args = {"complete_input_dict" : request_data },
79
142
)
80
143
144
+ ## RESPONSE TRANSFORMATION
145
+ # Snowflake returns content_list (not content) with tool_use objects
146
+ # We need to transform this to OpenAI's format with content + tool_calls
147
+ if "choices" in response_json and len (response_json ["choices" ]) > 0 :
148
+ choice = response_json ["choices" ][0 ]
149
+ if "message" in choice and "content_list" in choice ["message" ]:
150
+ content_list = choice ["message" ]["content_list" ]
151
+ (
152
+ text_content ,
153
+ tool_calls ,
154
+ ) = self ._transform_tool_calls_from_snowflake_to_openai (content_list )
155
+
156
+ # Update the choice message with OpenAI format
157
+ choice ["message" ]["content" ] = text_content
158
+ if tool_calls :
159
+ choice ["message" ]["tool_calls" ] = tool_calls
160
+
161
+ # Remove Snowflake-specific content_list
162
+ del choice ["message" ]["content_list" ]
163
+
81
164
returned_response = ModelResponse (** response_json )
82
165
83
166
returned_response .model = "snowflake/" + (returned_response .model or "" )
@@ -150,6 +233,95 @@ def get_complete_url(
150
233
151
234
return api_base
152
235
236
+ def _transform_tools (self , tools : List [Dict [str , Any ]]) -> List [Dict [str , Any ]]:
237
+ """
238
+ Transform OpenAI tool format to Snowflake tool format.
239
+
240
+ Args:
241
+ tools: List of tools in OpenAI format
242
+
243
+ Returns:
244
+ List of tools in Snowflake format
245
+
246
+ OpenAI format:
247
+ {
248
+ "type": "function",
249
+ "function": {
250
+ "name": "get_weather",
251
+ "description": "...",
252
+ "parameters": {...}
253
+ }
254
+ }
255
+
256
+ Snowflake format:
257
+ {
258
+ "tool_spec": {
259
+ "type": "generic",
260
+ "name": "get_weather",
261
+ "description": "...",
262
+ "input_schema": {...}
263
+ }
264
+ }
265
+ """
266
+ snowflake_tools : List [Dict [str , Any ]] = []
267
+ for tool in tools :
268
+ if tool .get ("type" ) == "function" :
269
+ function = tool .get ("function" , {})
270
+ snowflake_tool : Dict [str , Any ] = {
271
+ "tool_spec" : {
272
+ "type" : "generic" ,
273
+ "name" : function .get ("name" ),
274
+ "input_schema" : function .get (
275
+ "parameters" ,
276
+ {"type" : "object" , "properties" : {}},
277
+ ),
278
+ }
279
+ }
280
+ # Add description if present
281
+ if "description" in function :
282
+ snowflake_tool ["tool_spec" ]["description" ] = function [
283
+ "description"
284
+ ]
285
+
286
+ snowflake_tools .append (snowflake_tool )
287
+
288
+ return snowflake_tools
289
+
290
+ def _transform_tool_choice (
291
+ self , tool_choice : Union [str , Dict [str , Any ]]
292
+ ) -> Union [str , Dict [str , Any ]]:
293
+ """
294
+ Transform OpenAI tool_choice format to Snowflake format.
295
+
296
+ Args:
297
+ tool_choice: Tool choice in OpenAI format (str or dict)
298
+
299
+ Returns:
300
+ Tool choice in Snowflake format
301
+
302
+ OpenAI format:
303
+ {"type": "function", "function": {"name": "get_weather"}}
304
+
305
+ Snowflake format:
306
+ {"type": "tool", "name": ["get_weather"]}
307
+
308
+ Note: String values ("auto", "required", "none") pass through unchanged.
309
+ """
310
+ if isinstance (tool_choice , str ):
311
+ # "auto", "required", "none" pass through as-is
312
+ return tool_choice
313
+
314
+ if isinstance (tool_choice , dict ):
315
+ if tool_choice .get ("type" ) == "function" :
316
+ function_name = tool_choice .get ("function" , {}).get ("name" )
317
+ if function_name :
318
+ return {
319
+ "type" : "tool" ,
320
+ "name" : [function_name ], # Snowflake expects array
321
+ }
322
+
323
+ return tool_choice
324
+
153
325
def transform_request (
154
326
self ,
155
327
model : str ,
@@ -160,6 +332,18 @@ def transform_request(
160
332
) -> dict :
161
333
stream : bool = optional_params .pop ("stream" , None ) or False
162
334
extra_body = optional_params .pop ("extra_body" , {})
335
+
336
+ ## TOOL CALLING
337
+ # Transform tools from OpenAI format to Snowflake's tool_spec format
338
+ tools = optional_params .pop ("tools" , None )
339
+ if tools :
340
+ optional_params ["tools" ] = self ._transform_tools (tools )
341
+
342
+ # Transform tool_choice from OpenAI format to Snowflake's tool name array format
343
+ tool_choice = optional_params .pop ("tool_choice" , None )
344
+ if tool_choice :
345
+ optional_params ["tool_choice" ] = self ._transform_tool_choice (tool_choice )
346
+
163
347
return {
164
348
"model" : model ,
165
349
"messages" : messages ,
0 commit comments