Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions mcp_server_snowflake/object_manager/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ def parse_object(target_object: Any, obj_type: supported_objects):


def initialize_object_manager_tools(server: FastMCP, snowflake_service):
root = snowflake_service.root
supported_objects_list = list(get_args(supported_objects))
object_type_annotation = Annotated[
supported_objects,
Expand Down Expand Up @@ -221,7 +220,8 @@ def create_object_tool(
):
# If string is passed, parse JSON and create object
target_object = parse_object(target_object, object_type)
return create_object(target_object, root, mode)
snowflake_service._ensure_connected()
return create_object(target_object, snowflake_service.root, mode)

@server.tool(
name="drop_object",
Expand All @@ -233,7 +233,8 @@ def drop_object_tool(
if_exists: bool = False,
):
target_object = parse_object(target_object, object_type)
return drop_object(target_object, root, if_exists)
snowflake_service._ensure_connected()
return drop_object(target_object, snowflake_service.root, if_exists)

@server.tool(
name="create_or_alter_object",
Expand All @@ -244,7 +245,8 @@ def create_or_alter_object_tool(
target_object: target_object_annotation,
):
target_object = parse_object(target_object, object_type)
return create_or_alter_object(target_object, root)
snowflake_service._ensure_connected()
return create_or_alter_object(target_object, snowflake_service.root)

@server.tool(
name="describe_object",
Expand All @@ -255,7 +257,8 @@ def describe_object_tool(
target_object: target_object_annotation,
):
target_object = parse_object(target_object, object_type)
return describe_object(target_object, root)
snowflake_service._ensure_connected()
return describe_object(target_object, snowflake_service.root)

@server.tool(
name="list_objects",
Expand Down
40 changes: 14 additions & 26 deletions mcp_server_snowflake/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,16 @@ def __init__(
self._is_spcs_container = is_running_in_spcs_container()

self.unpack_service_specs()
# Persist connection to avoid closing it after each request
self.connection = self._get_persistent_connection()
self.root = Root(self.connection)
# Connection is lazily established on first tool use to avoid
# triggering SSO/Okta auth on MCP server startup.
self.connection = None
self.root = None

def _ensure_connected(self) -> None:
"""Lazily establish the Snowflake connection on first use."""
if self.connection is None:
self.connection = self._get_persistent_connection()
self.root = Root(self.connection)

def unpack_service_specs(self) -> None:
"""
Expand Down Expand Up @@ -241,6 +248,7 @@ def get_api_headers(self) -> Dict[str, str]:
}
else:
# For external environments, we need to use the connection token
self._ensure_connected()
return {
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
Expand All @@ -261,6 +269,7 @@ def get_api_host(self) -> str:
"SNOWFLAKE_HOST", self.connection_params.get("account", "")
)
else:
self._ensure_connected()
return self.connection.host

@staticmethod
Expand Down Expand Up @@ -334,6 +343,7 @@ def _get_persistent_connection(
**connection_params,
session_parameters=session_parameters,
client_session_keep_alive=True,
client_store_temporary_credential=True,
paramstyle="qmark",
)
if connection: # Send zero compute query to capture query tag
Expand Down Expand Up @@ -378,29 +388,7 @@ def get_connection(
"""

try:
if self.connection is None:
# Get connection parameters based on environment
if self._is_spcs_container:
logger.info("Using SPCS container OAuth authentication")
connection_params = {
"host": os.getenv("SNOWFLAKE_HOST"),
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
"token": get_spcs_container_token(),
"authenticator": "oauth",
}
connection_params = {
k: v for k, v in connection_params.items() if v is not None
}
else:
logger.info("Using external authentication")
connection_params = self.connection_params.copy()

self.connection = connect(
**connection_params,
session_parameters=session_parameters,
client_session_keep_alive=False,
paramstyle="qmark",
)
self._ensure_connected()

cursor = (
self.connection.cursor(DictCursor)
Expand Down
Loading