Skip to content

Commit 9383ffc

Browse files
Merge pull request #15221 from jatorre/feat/snowflake-tools-clean
feat(snowflake): add function calling support for Snowflake Cortex REST API
2 parents 30d24fa + df232a7 commit 9383ffc

File tree

3 files changed

+573
-7
lines changed

3 files changed

+573
-7
lines changed

litellm/llms/snowflake/chat/transformation.py

Lines changed: 190 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""
2-
Support for Snowflake REST API
2+
Support for Snowflake REST API
33
"""
44

5-
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
5+
import json
6+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
67

78
import httpx
89

910
from litellm.secret_managers.main import get_secret_str
1011
from litellm.types.llms.openai import AllMessageValues
11-
from litellm.types.utils import ModelResponse
12+
from litellm.types.utils import ChatCompletionMessageToolCall, Function, ModelResponse
1213

1314
from ...openai_like.chat.transformation import OpenAIGPTConfig
1415

@@ -22,15 +23,25 @@
2223

2324
class SnowflakeConfig(OpenAIGPTConfig):
2425
"""
25-
source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
26+
Reference: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api
27+
28+
Snowflake Cortex LLM REST API supports function calling with specific models (e.g., Claude 3.5 Sonnet).
29+
This config handles transformation between OpenAI format and Snowflake's tool_spec format.
2630
"""
2731

2832
@classmethod
2933
def get_config(cls):
3034
return super().get_config()
3135

32-
def get_supported_openai_params(self, model: str) -> List:
33-
return ["temperature", "max_tokens", "top_p", "response_format"]
36+
def get_supported_openai_params(self, model: str) -> List[str]:
37+
return [
38+
"temperature",
39+
"max_tokens",
40+
"top_p",
41+
"response_format",
42+
"tools",
43+
"tool_choice",
44+
]
3445

3546
def map_openai_params(
3647
self,
@@ -56,6 +67,57 @@ def map_openai_params(
5667
optional_params[param] = value
5768
return optional_params
5869

70+
def _transform_tool_calls_from_snowflake_to_openai(
71+
self, content_list: List[Dict[str, Any]]
72+
) -> Tuple[str, Optional[List[ChatCompletionMessageToolCall]]]:
73+
"""
74+
Transform Snowflake tool calls to OpenAI format.
75+
76+
Args:
77+
content_list: Snowflake's content_list array containing text and tool_use items
78+
79+
Returns:
80+
Tuple of (text_content, tool_calls)
81+
82+
Snowflake format in content_list:
83+
{
84+
"type": "tool_use",
85+
"tool_use": {
86+
"tool_use_id": "tooluse_...",
87+
"name": "get_weather",
88+
"input": {"location": "Paris"}
89+
}
90+
}
91+
92+
OpenAI format (returned tool_calls):
93+
ChatCompletionMessageToolCall(
94+
id="tooluse_...",
95+
type="function",
96+
function=Function(name="get_weather", arguments='{"location": "Paris"}')
97+
)
98+
"""
99+
text_content = ""
100+
tool_calls: List[ChatCompletionMessageToolCall] = []
101+
102+
for idx, content_item in enumerate(content_list):
103+
if content_item.get("type") == "text":
104+
text_content += content_item.get("text", "")
105+
106+
## TOOL CALLING
107+
elif content_item.get("type") == "tool_use":
108+
tool_use_data = content_item.get("tool_use", {})
109+
tool_call = ChatCompletionMessageToolCall(
110+
id=tool_use_data.get("tool_use_id", ""),
111+
type="function",
112+
function=Function(
113+
name=tool_use_data.get("name", ""),
114+
arguments=json.dumps(tool_use_data.get("input", {})),
115+
),
116+
)
117+
tool_calls.append(tool_call)
118+
119+
return text_content, tool_calls if tool_calls else None
120+
59121
def transform_response(
60122
self,
61123
model: str,
@@ -71,13 +133,34 @@ def transform_response(
71133
json_mode: Optional[bool] = None,
72134
) -> ModelResponse:
73135
response_json = raw_response.json()
136+
74137
logging_obj.post_call(
75138
input=messages,
76139
api_key="",
77140
original_response=response_json,
78141
additional_args={"complete_input_dict": request_data},
79142
)
80143

144+
## RESPONSE TRANSFORMATION
145+
# Snowflake returns content_list (not content) with tool_use objects
146+
# We need to transform this to OpenAI's format with content + tool_calls
147+
if "choices" in response_json and len(response_json["choices"]) > 0:
148+
choice = response_json["choices"][0]
149+
if "message" in choice and "content_list" in choice["message"]:
150+
content_list = choice["message"]["content_list"]
151+
(
152+
text_content,
153+
tool_calls,
154+
) = self._transform_tool_calls_from_snowflake_to_openai(content_list)
155+
156+
# Update the choice message with OpenAI format
157+
choice["message"]["content"] = text_content
158+
if tool_calls:
159+
choice["message"]["tool_calls"] = tool_calls
160+
161+
# Remove Snowflake-specific content_list
162+
del choice["message"]["content_list"]
163+
81164
returned_response = ModelResponse(**response_json)
82165

83166
returned_response.model = "snowflake/" + (returned_response.model or "")
@@ -150,6 +233,95 @@ def get_complete_url(
150233

151234
return api_base
152235

236+
def _transform_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
237+
"""
238+
Transform OpenAI tool format to Snowflake tool format.
239+
240+
Args:
241+
tools: List of tools in OpenAI format
242+
243+
Returns:
244+
List of tools in Snowflake format
245+
246+
OpenAI format:
247+
{
248+
"type": "function",
249+
"function": {
250+
"name": "get_weather",
251+
"description": "...",
252+
"parameters": {...}
253+
}
254+
}
255+
256+
Snowflake format:
257+
{
258+
"tool_spec": {
259+
"type": "generic",
260+
"name": "get_weather",
261+
"description": "...",
262+
"input_schema": {...}
263+
}
264+
}
265+
"""
266+
snowflake_tools: List[Dict[str, Any]] = []
267+
for tool in tools:
268+
if tool.get("type") == "function":
269+
function = tool.get("function", {})
270+
snowflake_tool: Dict[str, Any] = {
271+
"tool_spec": {
272+
"type": "generic",
273+
"name": function.get("name"),
274+
"input_schema": function.get(
275+
"parameters",
276+
{"type": "object", "properties": {}},
277+
),
278+
}
279+
}
280+
# Add description if present
281+
if "description" in function:
282+
snowflake_tool["tool_spec"]["description"] = function[
283+
"description"
284+
]
285+
286+
snowflake_tools.append(snowflake_tool)
287+
288+
return snowflake_tools
289+
290+
def _transform_tool_choice(
291+
self, tool_choice: Union[str, Dict[str, Any]]
292+
) -> Union[str, Dict[str, Any]]:
293+
"""
294+
Transform OpenAI tool_choice format to Snowflake format.
295+
296+
Args:
297+
tool_choice: Tool choice in OpenAI format (str or dict)
298+
299+
Returns:
300+
Tool choice in Snowflake format
301+
302+
OpenAI format:
303+
{"type": "function", "function": {"name": "get_weather"}}
304+
305+
Snowflake format:
306+
{"type": "tool", "name": ["get_weather"]}
307+
308+
Note: String values ("auto", "required", "none") pass through unchanged.
309+
"""
310+
if isinstance(tool_choice, str):
311+
# "auto", "required", "none" pass through as-is
312+
return tool_choice
313+
314+
if isinstance(tool_choice, dict):
315+
if tool_choice.get("type") == "function":
316+
function_name = tool_choice.get("function", {}).get("name")
317+
if function_name:
318+
return {
319+
"type": "tool",
320+
"name": [function_name], # Snowflake expects array
321+
}
322+
323+
return tool_choice
324+
153325
def transform_request(
154326
self,
155327
model: str,
@@ -160,6 +332,18 @@ def transform_request(
160332
) -> dict:
161333
stream: bool = optional_params.pop("stream", None) or False
162334
extra_body = optional_params.pop("extra_body", {})
335+
336+
## TOOL CALLING
337+
# Transform tools from OpenAI format to Snowflake's tool_spec format
338+
tools = optional_params.pop("tools", None)
339+
if tools:
340+
optional_params["tools"] = self._transform_tools(tools)
341+
342+
# Transform tool_choice from OpenAI format to Snowflake's tool name array format
343+
tool_choice = optional_params.pop("tool_choice", None)
344+
if tool_choice:
345+
optional_params["tool_choice"] = self._transform_tool_choice(tool_choice)
346+
163347
return {
164348
"model": model,
165349
"messages": messages,

tests/llm_translation/test_snowflake.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
load_dotenv()
77
import pytest
88

9-
from litellm import completion, acompletion
9+
from litellm import completion, acompletion, responses
1010
from litellm.exceptions import APIConnectionError
1111

1212
@pytest.mark.parametrize("sync_mode", [True, False])
@@ -87,3 +87,70 @@ async def test_chat_completion_snowflake_stream(sync_mode):
8787
raise # Re-raise if it's a different APIConnectionError
8888
except Exception as e:
8989
pytest.fail(f"Error occurred: {e}")
90+
91+
92+
@pytest.mark.skip(reason="Requires Snowflake credentials - run manually when needed")
93+
def test_snowflake_tool_calling_responses_api():
94+
"""
95+
Test Snowflake tool calling with Responses API.
96+
Requires SNOWFLAKE_JWT and SNOWFLAKE_ACCOUNT_ID environment variables.
97+
"""
98+
import litellm
99+
100+
# Skip if credentials not available
101+
if not os.getenv("SNOWFLAKE_JWT") or not os.getenv("SNOWFLAKE_ACCOUNT_ID"):
102+
pytest.skip("Snowflake credentials not available")
103+
104+
litellm.drop_params = False # We now support tools!
105+
106+
tools = [
107+
{
108+
"type": "function",
109+
"name": "get_weather",
110+
"description": "Get the current weather in a given location",
111+
"parameters": {
112+
"type": "object",
113+
"properties": {
114+
"location": {
115+
"type": "string",
116+
"description": "The city and state, e.g. San Francisco, CA",
117+
}
118+
},
119+
"required": ["location"],
120+
},
121+
}
122+
]
123+
124+
try:
125+
# Test with tool_choice to force tool use
126+
response = responses(
127+
model="snowflake/claude-3-5-sonnet",
128+
input="What's the weather in Paris?",
129+
tools=tools,
130+
tool_choice={"type": "function", "function": {"name": "get_weather"}},
131+
max_output_tokens=200,
132+
)
133+
134+
assert response is not None
135+
assert hasattr(response, "output")
136+
assert len(response.output) > 0
137+
138+
# Verify tool call was made
139+
tool_call_found = False
140+
for item in response.output:
141+
if hasattr(item, "type") and item.type == "function_call":
142+
tool_call_found = True
143+
assert item.name == "get_weather"
144+
assert hasattr(item, "arguments")
145+
print(f"✅ Tool call detected: {item.name}({item.arguments})")
146+
break
147+
148+
assert tool_call_found, "Expected tool call but none was found"
149+
150+
except APIConnectionError as e:
151+
if "JWT token is invalid" in str(e):
152+
pytest.skip("Invalid Snowflake JWT token")
153+
elif "Application failed to respond" in str(e) or "502" in str(e):
154+
pytest.skip(f"Snowflake API unavailable: {e}")
155+
else:
156+
raise

0 commit comments

Comments
 (0)