Commit 5f14829: Move Agent code to cookbook folder
1 parent 0abaa02
File tree: 22 files changed, +3636 -0 lines changed
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
__pycache__
.DS_Store
.python-version
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Databricks MCP Assistant (with React UI)

A full-stack, Databricks-themed conversational assistant for supply chain queries, powered by OpenAI Agents and Databricks MCP servers. Includes a React chat UI and a FastAPI backend that streams agent responses.

---

## Features
- Conversational chat UI (React) with Databricks red palette
- FastAPI backend with streaming `/chat` endpoint
- Secure Databricks MCP integration
- Example agent logic and tool usage
- Modern UX, easy local development

---

## Project Structure
```
/ (root)
├── main.py              # Example CLI agent runner
├── api_server.py        # FastAPI backend for chat UI
├── requirements.txt     # Python dependencies
├── ui/                  # React frontend (Vite)
│   ├── src/components/ChatUI.jsx, ChatUI.css, ...
│   └── ...
└── README.md            # (this file)
```

---

## Quickstart

### 0. Databricks assets

You can kick-start your project with Databricks' Supply-Chain Optimization Solution Accelerator (or any other accelerator if you are working in a different industry). Clone the accelerator's GitHub repo into your Databricks workspace and run the bundled notebooks, starting with notebook 1:

https://github.com/lararachidi/agent-supply-chain/blob/main/README.md

These notebooks stand up every asset the agent will later reach via MCP, from raw enterprise tables and unstructured emails to classical ML models and graph workloads.

### 1. Prerequisites
- Python 3.10+
- Node.js 18+
- Databricks credentials in `~/.databrickscfg`
- (Optional) Virtualenv/pyenv for Python isolation

### 2. Install Python Dependencies
```bash
pip install -r requirements.txt
```

### 3. Start the Backend (FastAPI)

To start the backend, run:

```bash
python -m uvicorn api_server:app --reload --port 8000
```
- The API will be available at http://localhost:8000
- FastAPI docs: http://localhost:8000/docs

### 4. Start the Frontend (React UI)
In a separate terminal, run:
```bash
cd ui
npm install
npm run dev
```
- The app will be available at http://localhost:5173

---

## Usage
1. Open [http://localhost:5173](http://localhost:5173) in your browser.
2. Type a supply chain question (e.g., "What are the delays with distribution center 5?") and hit Send.
3. The agent will stream back a response from the Databricks MCP server.
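The streamed `/chat` endpoint can also be exercised without the React UI. Below is a minimal sketch of a plain-Python client, assuming the backend is running on localhost:8000 as in step 3 (standard library only; the JSON body matches the backend's `ChatRequest` model):

```python
import json
import urllib.request

def stream_chat(message: str, url: str = "http://localhost:8000/chat") -> str:
    """POST a question and collect the plain-text streamed reply chunk by chunk."""
    req = urllib.request.Request(
        url,
        data=json.dumps({"message": message}).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    chunks = []
    with urllib.request.urlopen(req) as resp:
        # Read in small chunks so partial output is visible as it arrives.
        for chunk in iter(lambda: resp.read(1024), b""):
            chunks.append(chunk.decode())
    return "".join(chunks)
```

For example, `print(stream_chat("What are the delays with distribution center 5?"))` prints the agent's streamed answer once the backend is up.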
---

## Troubleshooting
- **Port already in use:** Kill old processes with `lsof -ti:8000 | xargs kill -9` (for the backend) or change the port.
- **Frontend not loading:** Make sure you ran `npm install` and `npm run dev` in the `ui/` folder.

---

## Customization
- To change the agent's greeting, edit `ui/src/components/ChatUI.jsx`.
- To update backend agent logic, modify `api_server.py`.
- UI styling is in `ui/src/components/ChatUI.css` (Databricks red palette).

---

## Credits & References
- Inspired by the [Databricks Supply Chain Solution Accelerator](https://www.databricks.com/solutions/accelerators/supply-chain-distribution-optimization)
- Uses [openai-agents-python](https://github.com/openai/openai-agents-python)
- Databricks MCP integration via [databricks.sdk](https://github.com/databricks/databricks-sdk-py)
- Supply-chain scope enforced by a simple LLM guardrail (see `supply_chain_guardrails.py`)
Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
"""
FastAPI wrapper that exposes the agent as a streaming `/chat` endpoint.
"""

import os
import asyncio
import logging
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from agents.exceptions import (
    InputGuardrailTripwireTriggered,
    OutputGuardrailTripwireTriggered,
)
from agents import Agent, Runner, gen_trace_id, trace
from agents.mcp import MCPServerStreamableHttp, MCPServerStreamableHttpParams
from agents.model_settings import ModelSettings
from databricks_mcp import DatabricksOAuthClientProvider
from databricks.sdk import WorkspaceClient

from supply_chain_guardrails import supply_chain_guardrail

CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main")
SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db")
FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db")
DATABRICKS_PROFILE = os.getenv("DATABRICKS_PROFILE", "DEFAULT")
HTTP_TIMEOUT = 30.0  # seconds

app = FastAPI()

# Allow local dev front-end
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ChatRequest(BaseModel):
    message: str


async def build_mcp_servers():
    """Initialise Databricks MCP vector & UC-function servers."""
    ws = WorkspaceClient(profile=DATABRICKS_PROFILE)
    token = DatabricksOAuthClientProvider(ws).get_token()

    base = ws.config.host
    vector_url = f"{base}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}"
    fn_url = f"{base}/api/2.0/mcp/functions/{FUNCTIONS_PATH}"

    async def _proxy_tool(request_json: dict, url: str):
        import httpx

        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
            resp = await client.post(url, json=request_json, headers=headers)
            resp.raise_for_status()
            return resp.json()

    headers = {"Authorization": f"Bearer {token}"}

    servers = [
        MCPServerStreamableHttp(
            MCPServerStreamableHttpParams(
                url=vector_url,
                headers=headers,
                timeout=HTTP_TIMEOUT,
            ),
            name="vector_search",
            client_session_timeout_seconds=60,
        ),
        MCPServerStreamableHttp(
            MCPServerStreamableHttpParams(
                url=fn_url,
                headers=headers,
                timeout=HTTP_TIMEOUT,
            ),
            name="uc_functions",
            client_session_timeout_seconds=60,
        ),
    ]

    # Ensure servers are initialized before use
    await asyncio.gather(*(s.connect() for s in servers))
    return servers


@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    try:
        servers = await build_mcp_servers()

        agent = Agent(
            name="Assistant",
            instructions="Use the tools to answer the questions.",
            mcp_servers=servers,
            model_settings=ModelSettings(tool_choice="required"),
            output_guardrails=[supply_chain_guardrail],
        )

        trace_id = gen_trace_id()

        async def agent_stream():
            logging.info(f"[AGENT_STREAM] Input message: {req.message}")
            try:
                with trace(workflow_name="Databricks MCP Example", trace_id=trace_id):
                    result = await Runner.run(starting_agent=agent, input=req.message)
                    logging.info(f"[AGENT_STREAM] Raw agent result: {result}")
                    try:
                        logging.info(
                            f"[AGENT_STREAM] RunResult __dict__: {getattr(result, '__dict__', str(result))}"
                        )
                        raw_responses = getattr(result, "raw_responses", None)
                        logging.info(f"[AGENT_STREAM] RunResult raw_responses: {raw_responses}")
                    except Exception as log_exc:
                        logging.warning(f"[AGENT_STREAM] Could not log RunResult details: {log_exc}")
                    yield result.final_output
            except InputGuardrailTripwireTriggered:
                # Off-topic question denied by guardrail
                yield "Sorry, I can only help with supply-chain questions."
            except OutputGuardrailTripwireTriggered:
                # Out-of-scope answer blocked by guardrail
                yield "Sorry, I can only help with supply-chain questions."
            except Exception:
                logging.exception("[AGENT_STREAM] Exception during agent run")
                yield "[ERROR] Exception during agent run. Check backend logs for details."

        return StreamingResponse(agent_stream(), media_type="text/plain")

    except Exception:
        logging.exception("chat_endpoint failed")
        return StreamingResponse(
            (line.encode() for line in ["Internal server error 🙈"]),
            media_type="text/plain",
            status_code=500,
        )
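The MCP endpoint URLs built inside `build_mcp_servers` are plain string compositions over the config values at the top of the file. A standalone sketch of that composition; the host below is a placeholder, where the real code reads `ws.config.host` from the Databricks SDK:

```python
import os

CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main")
SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db")
FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db")

# Placeholder workspace host for illustration only.
base = "https://example.cloud.databricks.com"

vector_url = f"{base}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}"
fn_url = f"{base}/api/2.0/mcp/functions/{FUNCTIONS_PATH}"
```

With the default environment values this yields `.../api/2.0/mcp/vector-search/main/supply_chain_db` and `.../api/2.0/mcp/functions/main/supply_chain_db`, which is where the two `MCPServerStreamableHttp` servers point.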
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
"""
Databricks OAuth client provider for MCP servers.
"""

class DatabricksOAuthClientProvider:
    def __init__(self, ws):
        self.ws = ws

    def get_token(self):
        # For Databricks SDK >=0.57.0, the token is available as ws.config.token
        return self.ws.config.token
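A minimal usage sketch of the provider with a stand-in workspace object; in real use you would pass a `WorkspaceClient(profile=...)`, and the token value here is fake. The class is repeated so the sketch runs on its own:

```python
from types import SimpleNamespace

class DatabricksOAuthClientProvider:
    """Same shape as the class above, repeated so this sketch is self-contained."""
    def __init__(self, ws):
        self.ws = ws

    def get_token(self):
        return self.ws.config.token

# Stand-in for databricks.sdk.WorkspaceClient; only .config.token is accessed.
fake_ws = SimpleNamespace(config=SimpleNamespace(token="dapi-fake-token"))
assert DatabricksOAuthClientProvider(fake_ws).get_token() == "dapi-fake-token"
```

Because the provider only touches `ws.config.token`, any object exposing that attribute works, which keeps the class easy to test without a live workspace.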
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
"""
CLI assistant that uses Databricks MCP Vector Search and UC Functions via the OpenAI Agents SDK.
"""

import asyncio
import os
import httpx
from typing import Dict, Any
from agents import Agent, Runner, function_tool, gen_trace_id, trace
from agents.exceptions import (
    InputGuardrailTripwireTriggered,
    OutputGuardrailTripwireTriggered,
)
from agents.model_settings import ModelSettings
from databricks_mcp import DatabricksOAuthClientProvider
from databricks.sdk import WorkspaceClient
from supply_chain_guardrails import supply_chain_guardrail

CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main")
SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db")
FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db")
DATABRICKS_PROFILE = os.getenv("DATABRICKS_PROFILE", "DEFAULT")
HTTP_TIMEOUT = 30.0  # seconds


async def _databricks_ctx():
    """Return (workspace, token, base_url)."""
    ws = WorkspaceClient(profile=DATABRICKS_PROFILE)
    token = DatabricksOAuthClientProvider(ws).get_token()
    return ws, token, ws.config.host


@function_tool
async def vector_search(query: str) -> Dict[str, Any]:
    """Query the Databricks MCP Vector Search index."""
    ws, token, base_url = await _databricks_ctx()
    url = f"{base_url}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}"
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
        resp = await client.post(url, json={"query": query}, headers=headers)
        resp.raise_for_status()
        return resp.json()


@function_tool
async def uc_function(function_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
    """Invoke a Databricks Unity Catalog function with parameters."""
    ws, token, base_url = await _databricks_ctx()
    url = f"{base_url}/api/2.0/mcp/functions/{FUNCTIONS_PATH}"
    headers = {"Authorization": f"Bearer {token}"}
    payload = {"function": function_name, "params": params}
    async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
        resp = await client.post(url, json=payload, headers=headers)
        resp.raise_for_status()
        return resp.json()


async def run_agent():
    agent = Agent(
        name="Assistant",
        instructions=(
            "You are a supply-chain assistant for Databricks MCP; you must answer "
            "**only** questions that are **strictly** about supply-chain data, "
            "logistics, inventory, procurement, demand forecasting, etc; for every "
            "answer you must call one of the registered tools; if the user asks "
            "anything not related to supply chain, reply **exactly** with "
            "'Sorry, I can only help with supply-chain questions'."
        ),
        tools=[vector_search, uc_function],
        model="gpt-4o",
        model_settings=ModelSettings(tool_choice="required"),
        output_guardrails=[supply_chain_guardrail],
    )

    print("Databricks MCP assistant ready. Type a question or 'exit' to quit.")

    while True:
        user_input = input("You: ").strip()
        if user_input.lower() in {"exit", "quit"}:
            break

        trace_id = gen_trace_id()
        with trace(workflow_name="Databricks MCP Agent", trace_id=trace_id):
            try:
                result = await Runner.run(starting_agent=agent, input=user_input)
                print("Assistant:", result.final_output)
            except InputGuardrailTripwireTriggered:
                print("Assistant: Sorry, I can only help with supply-chain questions.")
            except OutputGuardrailTripwireTriggered:
                print("Assistant: Sorry, I can only help with supply-chain questions.")


def main():
    asyncio.run(run_agent())


if __name__ == "__main__":
    main()
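The `uc_function` tool above posts a small JSON body of the shape `{"function": ..., "params": ...}`. A sketch of that request body, using a made-up function name and parameters rather than a real Unity Catalog function:

```python
import json

def build_uc_payload(function_name: str, params: dict) -> str:
    """Mirror the {"function": ..., "params": ...} body posted by uc_function."""
    return json.dumps({"function": function_name, "params": params})

# "forecast_demand" and its params are illustrative only, not a real UC function.
body = build_uc_payload("forecast_demand", {"distribution_center": 5})
print(body)  # {"function": "forecast_demand", "params": {"distribution_center": 5}}
```

The actual function names and parameter shapes come from whatever UC functions the accelerator notebooks registered under `FUNCTIONS_PATH`.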
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
fastapi==0.115.13
uvicorn==0.34.3
pydantic==2.11.7
databricks-sdk==0.57.0
httpx==0.28.1
openai-agents==0.0.19
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
"""
Output guardrail that blocks answers not related to supply-chain topics.
"""
from __future__ import annotations

from pydantic import BaseModel
from agents import Agent, Runner, GuardrailFunctionOutput
from agents import output_guardrail
from agents.run_context import RunContextWrapper

class SupplyChainCheckOutput(BaseModel):
    reasoning: str
    is_supply_chain: bool


guardrail_agent = Agent(
    name="Supply-chain check",
    instructions=(
        "Check if the text is within the domain of supply-chain analytics and operations. "
        "Return JSON strictly matching the SupplyChainCheckOutput schema."
    ),
    output_type=SupplyChainCheckOutput,
)


@output_guardrail
async def supply_chain_guardrail(
    ctx: RunContextWrapper, agent: Agent, output
) -> GuardrailFunctionOutput:
    """Output guardrail that blocks non-supply-chain answers."""
    text = output if isinstance(output, str) else getattr(output, "response", str(output))
    result = await Runner.run(guardrail_agent, text, context=ctx.context)
    return GuardrailFunctionOutput(
        output_info=result.final_output,
        tripwire_triggered=not result.final_output.is_supply_chain,
    )
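The tripwire decision itself is simple enough to sketch in isolation. Below, plain dataclasses stand in for the agents-SDK types (`SupplyChainCheckOutput`, `GuardrailFunctionOutput`) so the logic runs without the SDK installed; in the real guardrail the verdict comes from `guardrail_agent` rather than hard-coded values:

```python
from dataclasses import dataclass

@dataclass
class SupplyChainCheckOutput:
    reasoning: str
    is_supply_chain: bool

@dataclass
class GuardrailFunctionOutput:
    output_info: SupplyChainCheckOutput
    tripwire_triggered: bool

def to_guardrail_output(check: SupplyChainCheckOutput) -> GuardrailFunctionOutput:
    # The tripwire fires exactly when the checker judges the text off-topic.
    return GuardrailFunctionOutput(
        output_info=check,
        tripwire_triggered=not check.is_supply_chain,
    )

on_topic = SupplyChainCheckOutput("mentions inventory levels", True)
off_topic = SupplyChainCheckOutput("asks about movie trivia", False)
assert not to_guardrail_output(on_topic).tripwire_triggered
assert to_guardrail_output(off_topic).tripwire_triggered
```

When the tripwire fires, the callers in `api_server.py` and `main.py` catch `OutputGuardrailTripwireTriggered` and reply with the fixed refusal message.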
