Skip to content

Commit 070743e

Browse files
authored
MCP Supply Chain Cookbook Improvements (#1944)
1 parent 375f987 commit 070743e

File tree

24 files changed

+4021
-118
lines changed

24 files changed

+4021
-118
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
__pycache__
2+
.DS_Store
3+
.python-version
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Databricks MCP Assistant (with React UI)
2+
3+
A full-stack, Databricks-themed conversational assistant for supply chain queries, powered by OpenAI Agents and Databricks MCP servers. Includes a React chat UI and a FastAPI backend that streams agent responses.
4+
5+
---
6+
7+
## Features
8+
- Conversational chat UI (React) with Databricks red palette
9+
- FastAPI backend with streaming `/chat` endpoint
10+
- Secure Databricks MCP integration
11+
- Example agent logic and tool usage
12+
- Modern UX, easy local development
13+
14+
15+
## Quickstart
16+
17+
### 0. Databricks assets
18+
19+
You can kick-start your project with Databricks’ Supply-Chain Optimization Solution Accelerator (or any other accelerator if you are working in a different industry). Clone the accelerator’s GitHub repo into your Databricks workspace and run the bundled notebooks, starting with notebook 1:
20+
21+
https://github.com/lara-openai/databricks-supply-chain
22+
23+
These notebooks stand up every asset the Agent will later reach via MCP, from raw enterprise tables and unstructured e-mails to classical ML models and graph workloads.
24+
25+
### 1. Prerequisites
26+
- Python 3.10+
27+
- Node.js 18+
28+
- Databricks credentials in `~/.databrickscfg`
29+
- OpenAI API key
30+
- (Optional) Virtualenv/pyenv for Python isolation
31+
32+
### 2. Install Python Dependencies
33+
```bash
34+
pip install -r requirements.txt
35+
```
36+
37+
### 3. Start the Backend (FastAPI)
38+
39+
To kick off the backend, run:
40+
41+
```bash
42+
python -m uvicorn api_server:app --reload --port 8000
43+
```
44+
- The API will be available at http://localhost:8000
45+
- FastAPI docs: http://localhost:8000/docs
46+
47+
### 4. Start the Frontend (React UI)
48+
In a different terminal, run the following:
49+
```bash
50+
cd ui
51+
npm install
52+
npm run dev
53+
```
54+
- The app will be available at http://localhost:5173
55+
56+
---
57+
58+
## Usage
59+
1. Open [http://localhost:5173](http://localhost:5173) in your browser.
60+
2. Type a supply chain question (e.g., "What are the delays with distribution center 5?") and hit Send.
61+
3. The agent will stream back a response from the Databricks MCP server.
62+
63+
---
64+
65+
## Troubleshooting
66+
- **Port already in use:** Kill old processes with `lsof -ti:8000 | xargs kill -9` (for backend) or change the port.
67+
- **Frontend not loading:** Make sure you ran `npm install` and `npm run dev` in the `ui/` folder.
68+
69+
---
70+
71+
## Customization
72+
- To change the agent's greeting, edit `ui/src/components/ChatUI.jsx`.
73+
- To update backend agent logic, modify `api_server.py`.
74+
- UI styling is in `ui/src/components/ChatUI.css` (Databricks red palette).
75+
76+
77+
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""
2+
FastAPI wrapper that exposes the agent as a streaming `/chat` endpoint.
3+
"""
4+
5+
import os
6+
import asyncio
7+
import logging
8+
from fastapi import FastAPI
9+
from fastapi.responses import StreamingResponse
10+
from fastapi.middleware.cors import CORSMiddleware
11+
from pydantic import BaseModel
12+
from agents.exceptions import (
13+
InputGuardrailTripwireTriggered,
14+
OutputGuardrailTripwireTriggered,
15+
)
16+
from agents import Agent, Runner, gen_trace_id, trace
17+
from agents.mcp import MCPServerStreamableHttp, MCPServerStreamableHttpParams
18+
from agents.model_settings import ModelSettings
19+
from databricks_mcp import DatabricksOAuthClientProvider
20+
from databricks.sdk import WorkspaceClient
21+
22+
from supply_chain_guardrails import supply_chain_guardrail
23+
24+
# Unity Catalog catalog/schema of the vector-search index exposed via MCP.
CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main")
SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db")
# "catalog/schema" path under which UC functions are served by the MCP endpoint.
FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db")
# Profile name resolved from ~/.databrickscfg by the Databricks SDK.
DATABRICKS_PROFILE = os.getenv("DATABRICKS_PROFILE", "DEFAULT")
HTTP_TIMEOUT = 30.0  # seconds
29+
30+
app = FastAPI()

# Allow the local dev front-end (Vite's default port) to call this API from the
# browser; credentials are allowed so cookies/auth headers pass through CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
40+
41+
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    # The user's natural-language question for the agent.
    message: str
43+
44+
45+
async def build_mcp_servers():
    """Initialise Databricks MCP vector & UC-function servers.

    Builds authenticated Streamable-HTTP connections to the workspace's
    vector-search and UC-functions MCP endpoints, connects both concurrently,
    and returns them ready for use by an Agent.

    Returns:
        list: two connected ``MCPServerStreamableHttp`` instances, named
        ``vector_search`` and ``uc_functions`` in that order.
    """
    ws = WorkspaceClient(profile=DATABRICKS_PROFILE)
    token = DatabricksOAuthClientProvider(ws).get_token()

    base = ws.config.host
    vector_url = f"{base}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}"
    fn_url = f"{base}/api/2.0/mcp/functions/{FUNCTIONS_PATH}"

    # Bearer token is forwarded on every MCP request.
    headers = {"Authorization": f"Bearer {token}"}

    # NOTE: removed an unused inner `_proxy_tool` helper that was defined here
    # but never called anywhere — the MCP servers below do all the HTTP work.
    servers = [
        MCPServerStreamableHttp(
            MCPServerStreamableHttpParams(
                url=vector_url,
                headers=headers,
                timeout=HTTP_TIMEOUT,
            ),
            name="vector_search",
            client_session_timeout_seconds=60,
        ),
        MCPServerStreamableHttp(
            MCPServerStreamableHttpParams(
                url=fn_url,
                headers=headers,
                timeout=HTTP_TIMEOUT,
            ),
            name="uc_functions",
            client_session_timeout_seconds=60,
        ),
    ]

    # Ensure servers are initialized before use.
    await asyncio.gather(*(s.connect() for s in servers))
    return servers
89+
90+
91+
@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    """Run the agent on the user's message and stream back the final answer.

    Returns a plain-text StreamingResponse; guardrail tripwires are mapped to
    a fixed refusal message, and unexpected errors to a generic error string.

    NOTE(review): the MCP servers created per-request are never explicitly
    disconnected/cleaned up after the response — presumably left to GC;
    verify against the Agents SDK's recommended connect/cleanup lifecycle.
    """
    try:
        servers = await build_mcp_servers()

        agent = Agent(
            name="Assistant",
            instructions="Use the tools to answer the questions.",
            mcp_servers=servers,
            # Force the model to call a tool on every turn.
            model_settings=ModelSettings(tool_choice="required"),
            output_guardrails=[supply_chain_guardrail],
        )

        trace_id = gen_trace_id()

        async def agent_stream():
            # NOTE(review): the run is awaited to completion and the final
            # output yielded as one chunk — "streaming" in transport only,
            # not token-by-token.
            logging.info(f"[AGENT_STREAM] Input message: {req.message}")
            try:
                with trace(workflow_name="Databricks MCP Example", trace_id=trace_id):
                    result = await Runner.run(starting_agent=agent, input=req.message)
                    logging.info(f"[AGENT_STREAM] Raw agent result: {result}")
                    # Best-effort debug logging of RunResult internals; never
                    # allowed to break the response path.
                    try:
                        logging.info(
                            f"[AGENT_STREAM] RunResult __dict__: {getattr(result, '__dict__', str(result))}"
                        )
                        raw_responses = getattr(result, "raw_responses", None)
                        logging.info(f"[AGENT_STREAM] RunResult raw_responses: {raw_responses}")
                    except Exception as log_exc:
                        logging.warning(f"[AGENT_STREAM] Could not log RunResult details: {log_exc}")
                    yield result.final_output
            except InputGuardrailTripwireTriggered:
                # Off-topic question denied by guardrail
                yield "Sorry, I can only help with supply-chain questions."
            except OutputGuardrailTripwireTriggered:
                # Out-of-scope answer blocked by guardrail
                yield "Sorry, I can only help with supply-chain questions."
            except Exception:
                logging.exception("[AGENT_STREAM] Exception during agent run")
                yield "[ERROR] Exception during agent run. Check backend logs for details."

        return StreamingResponse(agent_stream(), media_type="text/plain")

    except Exception:
        # Setup failed (e.g. MCP connect/auth): return a 500 with a body.
        logging.exception("chat_endpoint failed")
        return StreamingResponse(
            (line.encode() for line in ["Internal server error 🙈"]),
            media_type="text/plain",
            status_code=500,
        )
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
Databricks OAuth client provider for MCP servers.
3+
"""
4+
5+
class DatabricksOAuthClientProvider:
    """Supplies Databricks auth tokens for MCP server requests."""

    def __init__(self, ws):
        # Hold the workspace client; the token lives on its config object.
        self.ws = ws

    def get_token(self):
        """Return the workspace token (Databricks SDK >= 0.57.0 exposes it on ws.config)."""
        config = self.ws.config
        return config.token
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""
2+
CLI assistant that uses Databricks MCP Vector Search and UC Functions via the OpenAI Agents SDK.
3+
"""
4+
5+
import asyncio
6+
import os
7+
import httpx
8+
from typing import Dict, Any
9+
from agents import Agent, Runner, function_tool, gen_trace_id, trace
10+
from agents.exceptions import (
11+
InputGuardrailTripwireTriggered,
12+
OutputGuardrailTripwireTriggered,
13+
)
14+
from agents.model_settings import ModelSettings
15+
from databricks_mcp import DatabricksOAuthClientProvider
16+
from databricks.sdk import WorkspaceClient
17+
from supply_chain_guardrails import supply_chain_guardrail
18+
19+
# Unity Catalog catalog/schema of the vector-search index exposed via MCP.
CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main")
SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db")
# "catalog/schema" path under which UC functions are served by the MCP endpoint.
FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db")
# Profile name resolved from ~/.databrickscfg by the Databricks SDK.
DATABRICKS_PROFILE = os.getenv("DATABRICKS_PROFILE", "DEFAULT")
HTTP_TIMEOUT = 30.0  # seconds
24+
25+
26+
async def _databricks_ctx():
    """Return (workspace, PAT token, base_url)."""
    workspace = WorkspaceClient(profile=DATABRICKS_PROFILE)
    provider = DatabricksOAuthClientProvider(workspace)
    return workspace, provider.get_token(), workspace.config.host
31+
32+
33+
@function_tool
async def vector_search(query: str) -> Dict[str, Any]:
    """Query Databricks MCP Vector Search index."""
    # Resolve workspace credentials once per call; the token goes in a
    # bearer header on the MCP endpoint request.
    ws, token, base_url = await _databricks_ctx()
    endpoint = f"{base_url}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}"
    auth_headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as http:
        response = await http.post(endpoint, json={"query": query}, headers=auth_headers)
        response.raise_for_status()
        return response.json()
43+
44+
45+
@function_tool
async def uc_function(function_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
    """Invoke a Databricks Unity Catalog function with parameters."""
    ws, token, base_url = await _databricks_ctx()
    endpoint = f"{base_url}/api/2.0/mcp/functions/{FUNCTIONS_PATH}"
    auth_headers = {"Authorization": f"Bearer {token}"}
    # The MCP functions endpoint takes the target function and its arguments
    # in a single JSON body.
    body = {"function": function_name, "params": params}
    async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as http:
        response = await http.post(endpoint, json=body, headers=auth_headers)
        response.raise_for_status()
        return response.json()
56+
57+
58+
async def run_agent():
    """Interactive CLI loop: read questions, run the agent, print answers.

    Guardrail tripwires (input or output) are caught and mapped to the fixed
    refusal message. Type 'exit' or 'quit' to leave the loop.
    """
    agent = Agent(
        name="Assistant",
        instructions="You are a supply-chain assistant for Databricks MCP; you must answer **only** questions that are **strictly** about supply-chain data, logistics, inventory, procurement, demand forecasting, etc; for every answer you must call one of the registered tools; if the user asks anything not related to supply chain, reply **exactly** with 'Sorry, I can only help with supply-chain questions'.",
        tools=[vector_search, uc_function],
        # BUG FIX: `model` is an Agent argument, not a ModelSettings field —
        # passing it to ModelSettings raises TypeError at construction time.
        model="gpt-4o",
        model_settings=ModelSettings(tool_choice="required"),
        output_guardrails=[supply_chain_guardrail],
    )

    print("Databricks MCP assistant ready. Type a question or 'exit' to quit.")

    while True:
        user_input = input("You: ").strip()
        if user_input.lower() in {"exit", "quit"}:
            break

        # One trace per question so each run is individually inspectable.
        trace_id = gen_trace_id()
        with trace(workflow_name="Databricks MCP Agent", trace_id=trace_id):
            try:
                result = await Runner.run(starting_agent=agent, input=user_input)
                print("Assistant:", result.final_output)
            except InputGuardrailTripwireTriggered:
                print("Assistant: Sorry, I can only help with supply-chain questions.")
            except OutputGuardrailTripwireTriggered:
                print("Assistant: Sorry, I can only help with supply-chain questions.")
83+
84+
85+
def main():
    """Synchronous entry point: drive the async CLI loop to completion."""
    asyncio.run(run_agent())


if __name__ == "__main__":
    main()
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
fastapi==0.115.13
2+
uvicorn==0.34.3
3+
pydantic==2.11.7
4+
databricks-sdk==0.57.0
5+
httpx==0.28.1
6+
openai-agents==0.0.19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""
2+
Output guardrail that blocks answers not related to supply-chain topics.
3+
"""
4+
from __future__ import annotations
5+
6+
from pydantic import BaseModel
7+
from agents import Agent, Runner, GuardrailFunctionOutput
8+
from agents import output_guardrail
9+
from agents.run_context import RunContextWrapper
10+
11+
class SupplyChainCheckOutput(BaseModel):
    """Structured verdict produced by the guardrail classifier agent."""

    # The classifier's explanation of its decision.
    reasoning: str
    # True when the text is within the supply-chain domain.
    is_supply_chain: bool
14+
15+
16+
# Lightweight classifier agent used by the output guardrail to decide whether
# a candidate answer stays within the supply-chain domain; its structured
# output drives the tripwire below.
guardrail_agent = Agent(
    name="Supply-chain check",
    instructions=(
        "Check if the text is within the domain of supply-chain analytics and operations "
        "Return JSON strictly matching the SupplyChainCheckOutput schema"
    ),
    output_type=SupplyChainCheckOutput,
)
24+
25+
26+
@output_guardrail
async def supply_chain_guardrail(
    ctx: RunContextWrapper, agent: Agent, output
) -> GuardrailFunctionOutput:
    """Output guardrail that blocks non-supply-chain answers"""
    # Normalise the agent output to plain text before classification.
    if isinstance(output, str):
        candidate_text = output
    else:
        candidate_text = getattr(output, "response", str(output))

    verdict = await Runner.run(guardrail_agent, candidate_text, context=ctx.context)
    check = verdict.final_output
    # Trip the wire whenever the classifier says the text is off-domain.
    return GuardrailFunctionOutput(
        output_info=check,
        tripwire_triggered=not check.is_supply_chain,
    )
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Logs
2+
logs
3+
*.log
4+
npm-debug.log*
5+
yarn-debug.log*
6+
yarn-error.log*
7+
pnpm-debug.log*
8+
lerna-debug.log*
9+
10+
node_modules
11+
dist
12+
dist-ssr
13+
*.local
14+
15+
# Editor directories and files
16+
.vscode/*
17+
!.vscode/extensions.json
18+
.idea
19+
.DS_Store
20+
*.suo
21+
*.ntvs*
22+
*.njsproj
23+
*.sln
24+
*.sw?

0 commit comments

Comments
 (0)