Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@

# COMMAND ----------

# MAGIC %pip install databricks-mcp
# MAGIC

# COMMAND ----------

# MAGIC %run ../_resources/01-setup

# COMMAND ----------
Expand All @@ -60,7 +65,8 @@
"system_prompt": "Your job is to provide customer help. call the tool to answer.",
"llm_endpoint_name": LLM_ENDPOINT_NAME,
"max_history_messages": 20,
"retriever_config": None
"retriever_config": None,
"mcp_server_urls": []
}
try:
with open('agent_config.yaml', 'w') as f:
Expand Down
87 changes: 75 additions & 12 deletions product_demos/Data-Science/ai-agent/02-agent-eval/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict, Union
from uuid import uuid4
Expand All @@ -10,17 +11,16 @@
DatabricksFunctionClient,
set_uc_function_client,
)

from databricks_mcp import DatabricksMCPClient
from databricks.sdk import WorkspaceClient
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
convert_to_openai_messages,
)
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool

from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode
Expand All @@ -35,18 +35,48 @@
ResponsesAgentStreamEvent,
)

# Enable MLflow LangChain auto-trace
# Enable LangChain autolog
mlflow.langchain.autolog()

# Required to use Unity Catalog UDFs as tools
set_uc_function_client(DatabricksFunctionClient())


class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], add_messages]
custom_inputs: Optional[dict[str, Any]]
custom_outputs: Optional[dict[str, Any]]


# Generic schema for MCP tools (allows any input)
GENERIC_SCHEMA = {
"title": "MCPToolArgs",
"type": "object",
"properties": {},
"additionalProperties": True
}


class MCPToolWrapper(BaseTool):
"""Wrap a Databricks MCP tool as a LangChain BaseTool"""

def __init__(self, name: str, description: str, server_url: str, ws_client: WorkspaceClient):
super().__init__(name=name, description=description, args_schema=GENERIC_SCHEMA)
# store server info internally (not a Pydantic field)
self._tool_data = {
"server_url": server_url,
"ws_client": ws_client,
}

def _run(self, **kwargs) -> str:
client = DatabricksMCPClient(
server_url=self._tool_data["server_url"],
workspace_client=self._tool_data["ws_client"]
)
response = client.call_tool(self.name, kwargs)
return "".join([c.text for c in response.content])


def create_tool_calling_agent(
model: LanguageModelLike,
tools: Union[ToolNode, Sequence[BaseTool]],
Expand Down Expand Up @@ -85,15 +115,20 @@ def __init__(
llm_endpoint_name: str = "databricks-meta-llama-3-70b-instruct",
system_prompt: Optional[str] = None,
retriever_config: Optional[dict] = None,
mcp_server_urls: Optional[Sequence[str]] = None,
max_history_messages: int = 20,
):
self.llm_endpoint_name = llm_endpoint_name
self.system_prompt = system_prompt
self.max_history_messages = max_history_messages

# Initialize LLM
self.llm = ChatDatabricks(endpoint=llm_endpoint_name)

# Load Unity Catalog tools
self.tools: list[BaseTool] = UCFunctionToolkit(function_names=list(uc_tool_names)).tools

# Add retriever if configured
if retriever_config:
self.tools.append(
VectorSearchRetrieverTool(
Expand All @@ -104,8 +139,25 @@ def __init__(
)
)

# Add MCP tools from URLs
if mcp_server_urls:
ws_client = WorkspaceClient()
for url in mcp_server_urls:
try:
client = DatabricksMCPClient(server_url=url, workspace_client=ws_client)
tool_defs = client.list_tools()
for t in tool_defs:
self.tools.append(MCPToolWrapper(t.name, t.description or t.name, url, ws_client))
print(f"Loaded MCP tools from {url}: {[t.name for t in self.tools if isinstance(t, MCPToolWrapper)]}")
except Exception as e:
print(f"Failed to load MCP server {url}: {e}")

# Create agent graph
self.agent = create_tool_calling_agent(self.llm, self.tools, system_prompt)

# -----------------------
# LangGraph Responses mapping
# -----------------------
def _responses_to_cc(self, message: dict[str, Any]) -> list[dict[str, Any]]:
msg_type = message.get("type")
if msg_type == "function_call":
Expand Down Expand Up @@ -160,6 +212,9 @@ def _langchain_to_responses(self, messages: list[dict[str, Any]]) -> list[dict[s
)]
return []

# -----------------------
# Predict methods
# -----------------------
@mlflow.trace(span_type=SpanType.AGENT)
def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
outputs = [
Expand All @@ -169,17 +224,16 @@ def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

@mlflow.trace(span_type=SpanType.AGENT)
def predict_stream(
self, request: ResponsesAgentRequest,
) -> Generator[ResponsesAgentStreamEvent, None, None]:
def predict_stream(self, request: ResponsesAgentRequest) -> Generator[ResponsesAgentStreamEvent, None, None]:
cc_msgs = []
mlflow.update_current_trace(request_preview=request.input[0].content)
for msg in request.input:
cc_msgs.extend(self._responses_to_cc(msg.model_dump()))

# Limit history to the most recent max_history_messages
# Limit history
if len(cc_msgs) > self.max_history_messages:
cc_msgs = cc_msgs[-self.max_history_messages:]

for event in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates", "messages"]):
if event[0] == "updates":
for node_data in event[1].values():
Expand All @@ -194,7 +248,10 @@ def predict_stream(
)
except Exception:
pass


# -----------------------
# Resource tracking
# -----------------------
def get_resources(self):
res = [DatabricksServingEndpoint(endpoint_name=self.llm.endpoint)]
for t in self.tools:
Expand All @@ -204,18 +261,24 @@ def get_resources(self):
res.append(DatabricksFunction(function_name=t.uc_function_name))
return res

# -----------------------
# Helper to list loaded tools
# -----------------------
def list_tools(self) -> list[str]:
return [t.name for t in self.tools]

# Load configuration values from YAML
# ==========================
# Instantiate from config
# ==========================
model_config = ModelConfig(development_config="../02-agent-eval/agent_config.yaml")

# Instantiate agent
AGENT = LangGraphResponsesAgent(
uc_tool_names=model_config.get("uc_tool_names"),
llm_endpoint_name=model_config.get("llm_endpoint_name"),
system_prompt=model_config.get("system_prompt"),
retriever_config=model_config.get("retriever_config"),
mcp_server_urls=model_config.get("mcp_server_urls"),
max_history_messages=model_config.get("max_history_messages"),
)

# Register agent with MLflow for inference
mlflow.models.set_model(AGENT)
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
config_version_name: better_prompt
config_version_name: model_with_mcp
input_example:
- content: Give me the orders for [email protected]
role: user
llm_endpoint_name: databricks-claude-3-7-sonnet
max_history_messages: 20
mcp_server_urls:
- https://adb-984752964297111.11.azuredatabricks.net/api/2.0/mcp/functions/system/ai
retriever_config: null
system_prompt: You are a telco assistant. Call the appropriate tool to help the user
with billing, support, or account info. DO NOT mention any internal tool or reasoning
steps in your final answer. Do not say according to records or imply that you are
looking up information.
uc_tool_names:
- main_build.dbdemos_ai_agent.*
- main_build.dbdemos_ai_agent.*
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@

# COMMAND ----------

# MAGIC %pip install databricks-mcp

# COMMAND ----------

from agent import AGENT

#Let's try our retriever to make sure we know have access to the wifi router pdf guide
Expand Down
125 changes: 125 additions & 0 deletions product_demos/Data-Science/ai-agent/mcp/agent_mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Databricks notebook source
# MAGIC %md
# MAGIC ## Exposing your tools from the MCP server
# MAGIC
# MAGIC If you want to expose your tools or data so that other agents would be able to access it, then you could leverage MCP servers. Databricks already provides [managed MCP servers](https://docs.databricks.com/aws/en/generative-ai/mcp/managed-mcp#available-managed-servers), so that you do not need to start from scratch.
# MAGIC
# MAGIC
# MAGIC These managed servers include the UC functions, vector search indexes, and Genie spaces. In our example, we will show how you can provide the agent, the tools from the Databricks Managed MCP servers such as the UC system ai functions, but in similar fashion you can configure it for any other either Databricks managed MCP server or custom.

# COMMAND ----------

# MAGIC %pip install -U -qqqq mlflow>=3.1.1 langchain langgraph databricks-langchain pydantic databricks-agents databricks-mcp unitycatalog-langchain[databricks] uv databricks-feature-engineering==0.12.1
# MAGIC dbutils.library.restartPython()

# COMMAND ----------

# MAGIC %run ../_resources/01-setup

# COMMAND ----------

# MAGIC %md
# MAGIC ## Databricks Managed MCP
# MAGIC
# MAGIC To use the UC system ai function from the Databricks Managed MCP server, we'd just need to add the relevant url in the [agent_config.yaml](https://adb-984752964297111.11.azuredatabricks.net/editor/files/2254899697811646?o=984752964297111) file.
# MAGIC
# MAGIC - **Unity Catalog system ai functions**: https://{workspace-hostname}/api/2.0/mcp/functions/system/ai

# COMMAND ----------

from databricks.sdk import WorkspaceClient
import os, sys, yaml, mlflow
import nest_asyncio
nest_asyncio.apply()

# --- Paths ---
agent_eval_path = os.path.abspath(os.path.join(os.getcwd(), "../02-agent-eval"))
sys.path.append(agent_eval_path)
conf_path = os.path.join(agent_eval_path, "agent_config.yaml")

# --- Use Databricks SDK to detect workspace URL ---
ws = WorkspaceClient()
workspace_url = ws.config.host.rstrip("/")
mcp_url = f"{workspace_url}/api/2.0/mcp/functions/system/ai"

# ==========================
# Update config for MCP
# ==========================
try:
config = yaml.safe_load(open(conf_path))
config["config_version_name"] = "model_with_mcp"
config["mcp_server_urls"] = [mcp_url]

with open(conf_path, "w") as f:
yaml.safe_dump(config, f, sort_keys=False)

except Exception as e:
print(f"Skipped MCP update: {e}")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Instantiate the agent
# MAGIC
# MAGIC Next, we will instantiate the agent and to make sure we have provided the tools from the managed MCP server, we will check all the available tools that the agent has. In the list below, besides the UC functions we have created, we will see also the tools from MCP server which has the UC system ai functions i.e **system__ai__python_exec**.

# COMMAND ----------

# ==========================
# Instantiate the agent
# ==========================
from agent import LangGraphResponsesAgent
import mlflow.models

model_config = mlflow.models.ModelConfig(development_config=conf_path)

AGENT = LangGraphResponsesAgent(
uc_tool_names=model_config.get("uc_tool_names"),
llm_endpoint_name=model_config.get("llm_endpoint_name"),
system_prompt=model_config.get("system_prompt"),
mcp_server_urls=model_config.get("mcp_server_urls"),
max_history_messages=model_config.get("max_history_messages"),
)

print("✅ Available tools:", AGENT.list_tools())

# COMMAND ----------

# MAGIC %md
# MAGIC ## Test the agent in the AI Playground

# COMMAND ----------

# MAGIC %md
# MAGIC To test the agent, you would just need to choose the end of your choice and add the neccessary tools. In our case we will add the tools from the managed MCP server option, by selecting the system ai functions in the UC function toolbox. Similarly, you could add any tool of your choice, also if you'd have custom or external MCP servers.

# COMMAND ----------

# MAGIC %md
# MAGIC <img src="https://raw.githubusercontent.com/databricks-demos/dbdemos-resources/main/images/product/ai-agent/pg-mcp-img1.png" width="800px">

# COMMAND ----------

# MAGIC %md When we start to ask questions, we will see that the agent has properly loaded the tool from the Managed MCP server we added above.

# COMMAND ----------

# MAGIC %md
# MAGIC <img src="https://raw.githubusercontent.com/databricks-demos/dbdemos-resources/main/images/product/ai-agent/pg-mcp-img2.png" width="800">

# COMMAND ----------

# MAGIC %md Now we can start exploring more...

# COMMAND ----------

# MAGIC %md
# MAGIC
# MAGIC <img src="https://raw.githubusercontent.com/databricks-demos/dbdemos-resources/main/images/product/ai-agent/pg-mcp-img3.png" width="800">
# MAGIC

# COMMAND ----------

# MAGIC %md
# MAGIC ## Next steps
# MAGIC If you would like to further explore other MCP server options in Databricks, please refer to [Databricks MCP documentation](https://docs.databricks.com/aws/en/generative-ai/mcp/).