File: .gitignore
__pycache__
.DS_Store
.python-version
File: README.md
# Databricks MCP Assistant (with React UI)

A full-stack, Databricks-themed conversational assistant for supply chain queries, powered by OpenAI Agents and Databricks MCP servers. Includes a React chat UI and a FastAPI backend that streams agent responses.

---

## Features
- Conversational chat UI (React) with Databricks red palette
- FastAPI backend with streaming `/chat` endpoint
- Secure Databricks MCP integration
- Example agent logic and tool usage
- Modern UX, easy local development


## Quickstart

### 0. Databricks assets

You can kick-start your project with Databricks' Supply Chain Optimization Solution Accelerator (or another accelerator if you are working in a different industry). Clone the accelerator's GitHub repo into your Databricks workspace and run the bundled notebooks, starting with notebook 1:

https://github.com/lara-openai/databricks-supply-chain

These notebooks stand up every asset the Agent will later reach via MCP, from raw enterprise tables and unstructured e-mails to classical ML models and graph workloads.
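
If you prefer the CLI to the workspace UI, cloning the accelerator into Databricks Repos looks roughly like this (a sketch assuming the legacy `databricks-cli` syntax; the newer unified CLI uses different flags, and the `/Repos/<your-user>/` path is a placeholder):

```bash
# Sketch (legacy Databricks CLI): clone the accelerator into Databricks Repos.
databricks repos create \
  --url https://github.com/lara-openai/databricks-supply-chain \
  --provider gitHub \
  --path "/Repos/<your-user>/databricks-supply-chain"
```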

### 1. Prerequisites
- Python 3.10+
- Node.js 18+
- Databricks credentials in `~/.databrickscfg` (see the example below)
- OpenAI API key
- (Optional) Virtualenv/pyenv for Python isolation
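
A minimal `~/.databrickscfg` looks like the sketch below; the host and token are placeholders, and the profile name must match `DATABRICKS_PROFILE` (the backend defaults to `DEFAULT`):

```ini
[DEFAULT]
host  = https://<your-workspace>.cloud.databricks.com
token = <your-personal-access-token>
```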

### 2. Install Python Dependencies
```bash
pip install -r requirements.txt
```
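
The backend is configured through environment variables. `OPENAI_API_KEY` is required by the OpenAI Agents SDK; the rest are optional overrides whose defaults (shown below) match the constants defined in `api_server.py`:

```bash
export OPENAI_API_KEY="<your-openai-api-key>"
# Optional overrides (defaults shown), read by api_server.py:
export DATABRICKS_PROFILE="DEFAULT"
export MCP_VECTOR_CATALOG="main"
export MCP_VECTOR_SCHEMA="supply_chain_db"
export MCP_FUNCTIONS_PATH="main/supply_chain_db"
```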

### 3. Start the Backend (FastAPI)

To kick off the backend, run:

```bash
python -m uvicorn api_server:app --reload --port 8000
```
- The API will be available at http://localhost:8000
- FastAPI docs: http://localhost:8000/docs
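
You can smoke-test the endpoint without the UI. `/chat` accepts a JSON body with a single `message` field and streams the reply as plain text (`-N` disables curl's buffering):

```bash
curl -N -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{"message": "What are the delays with distribution center 5?"}'
```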

### 4. Start the Frontend (React UI)
In a different terminal, run the following:
```bash
cd ui
npm install
npm run dev
```
- The app will be available at http://localhost:5173

---

## Usage
1. Open [http://localhost:5173](http://localhost:5173) in your browser.
2. Type a supply chain question (e.g., "What are the delays with distribution center 5?") and hit Send.
3. The agent will stream back a response from the Databricks MCP server.

---

## Troubleshooting
- **Port already in use:** Kill old processes with `lsof -ti:8000 | xargs kill -9` (for backend) or change the port.
- **Frontend not loading:** Make sure you ran `npm install` and `npm run dev` in the `ui/` folder.

---

## Customization
- To change the agent's greeting, edit `ui/src/components/ChatUI.jsx`.
- To update backend agent logic, modify `api_server.py`.
- UI styling is in `ui/src/components/ChatUI.css` (Databricks red palette).



File: api_server.py
"""
FastAPI wrapper that exposes the agent as a streaming `/chat` endpoint.
"""

import os
import asyncio
import logging
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from agents.exceptions import (
InputGuardrailTripwireTriggered,
OutputGuardrailTripwireTriggered,
)
from agents import Agent, Runner, gen_trace_id, trace
from agents.mcp import MCPServerStreamableHttp, MCPServerStreamableHttpParams
from agents.model_settings import ModelSettings
from databricks_mcp import DatabricksOAuthClientProvider
from databricks.sdk import WorkspaceClient

from supply_chain_guardrails import supply_chain_guardrail

CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main")
SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db")
FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db")
DATABRICKS_PROFILE = os.getenv("DATABRICKS_PROFILE", "DEFAULT")
HTTP_TIMEOUT = 30.0 # seconds

app = FastAPI()

# Allow local dev front‑end
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

class ChatRequest(BaseModel):
message: str


async def build_mcp_servers():
"""Initialise Databricks MCP vector & UC‑function servers."""
ws = WorkspaceClient(profile=DATABRICKS_PROFILE)
token = DatabricksOAuthClientProvider(ws).get_token()

base = ws.config.host
vector_url = f"{base}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}"
fn_url = f"{base}/api/2.0/mcp/functions/{FUNCTIONS_PATH}"

async def _proxy_tool(request_json: dict, url: str):
import httpx

headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
resp = await client.post(url, json=request_json, headers=headers)
resp.raise_for_status()
return resp.json()

headers = {"Authorization": f"Bearer {token}"}

servers = [
MCPServerStreamableHttp(
MCPServerStreamableHttpParams(
url=vector_url,
headers=headers,
timeout=HTTP_TIMEOUT,
),
name="vector_search",
client_session_timeout_seconds=60,
),
MCPServerStreamableHttp(
MCPServerStreamableHttpParams(
url=fn_url,
headers=headers,
timeout=HTTP_TIMEOUT,
),
name="uc_functions",
client_session_timeout_seconds=60,
),
]

# Ensure servers are initialized before use
await asyncio.gather(*(s.connect() for s in servers))
return servers


@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
try:
servers = await build_mcp_servers()

agent = Agent(
name="Assistant",
instructions="Use the tools to answer the questions.",
mcp_servers=servers,
model_settings=ModelSettings(tool_choice="required"),
output_guardrails=[supply_chain_guardrail],
)

trace_id = gen_trace_id()

async def agent_stream():
logging.info(f"[AGENT_STREAM] Input message: {req.message}")
try:
with trace(workflow_name="Databricks MCP Example", trace_id=trace_id):
result = await Runner.run(starting_agent=agent, input=req.message)
logging.info(f"[AGENT_STREAM] Raw agent result: {result}")
try:
logging.info(
f"[AGENT_STREAM] RunResult __dict__: {getattr(result, '__dict__', str(result))}"
)
raw_responses = getattr(result, "raw_responses", None)
logging.info(f"[AGENT_STREAM] RunResult raw_responses: {raw_responses}")
except Exception as log_exc:
logging.warning(f"[AGENT_STREAM] Could not log RunResult details: {log_exc}")
yield result.final_output
except InputGuardrailTripwireTriggered:
# Off-topic question denied by guardrail
yield "Sorry, I can only help with supply-chain questions."
except OutputGuardrailTripwireTriggered:
# Out-of-scope answer blocked by guardrail
yield "Sorry, I can only help with supply-chain questions."
except Exception:
logging.exception("[AGENT_STREAM] Exception during agent run")
yield "[ERROR] Exception during agent run. Check backend logs for details."

return StreamingResponse(agent_stream(), media_type="text/plain")

except Exception:
logging.exception("chat_endpoint failed")
return StreamingResponse(
(line.encode() for line in ["Internal server error 🙈"]),
media_type="text/plain",
status_code=500,
)
File: databricks_mcp.py
"""
Databricks OAuth client provider for MCP servers.
"""

class DatabricksOAuthClientProvider:
def __init__(self, ws):
self.ws = ws

def get_token(self):
# For Databricks SDK >=0.57.0, token is available as ws.config.token
return self.ws.config.token
File: CLI assistant script (filename not shown in the diff view)
"""
CLI assistant that uses Databricks MCP Vector Search and UC Functions via the OpenAI Agents SDK.
"""

import asyncio
import os
import httpx
from typing import Dict, Any
from agents import Agent, Runner, function_tool, gen_trace_id, trace
from agents.exceptions import (
InputGuardrailTripwireTriggered,
OutputGuardrailTripwireTriggered,
)
from agents.model_settings import ModelSettings
from databricks_mcp import DatabricksOAuthClientProvider
from databricks.sdk import WorkspaceClient
from supply_chain_guardrails import supply_chain_guardrail

CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main")
SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db")
FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db")
DATABRICKS_PROFILE = os.getenv("DATABRICKS_PROFILE", "DEFAULT")
HTTP_TIMEOUT = 30.0 # seconds


async def _databricks_ctx():
"""Return (workspace, PAT token, base_url)."""
ws = WorkspaceClient(profile=DATABRICKS_PROFILE)
token = DatabricksOAuthClientProvider(ws).get_token()
return ws, token, ws.config.host


@function_tool
async def vector_search(query: str) -> Dict[str, Any]:
"""Query Databricks MCP Vector Search index."""
ws, token, base_url = await _databricks_ctx()
url = f"{base_url}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}"
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
resp = await client.post(url, json={"query": query}, headers=headers)
resp.raise_for_status()
return resp.json()


@function_tool
async def uc_function(function_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
"""Invoke a Databricks Unity Catalog function with parameters."""
ws, token, base_url = await _databricks_ctx()
url = f"{base_url}/api/2.0/mcp/functions/{FUNCTIONS_PATH}"
headers = {"Authorization": f"Bearer {token}"}
payload = {"function": function_name, "params": params}
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
resp = await client.post(url, json=payload, headers=headers)
resp.raise_for_status()
return resp.json()


async def run_agent():
agent = Agent(
name="Assistant",
instructions="You are a supply-chain assistant for Databricks MCP; you must answer **only** questions that are **strictly** about supply-chain data, logistics, inventory, procurement, demand forecasting, etc; for every answer you must call one of the registered tools; if the user asks anything not related to supply chain, reply **exactly** with 'Sorry, I can only help with supply-chain questions'.",
tools=[vector_search, uc_function],
model_settings=ModelSettings(model="gpt-4o", tool_choice="required"),
output_guardrails=[supply_chain_guardrail],
)

print("Databricks MCP assistant ready. Type a question or 'exit' to quit.")

while True:
user_input = input("You: ").strip()
if user_input.lower() in {"exit", "quit"}:
break

trace_id = gen_trace_id()
with trace(workflow_name="Databricks MCP Agent", trace_id=trace_id):
try:
result = await Runner.run(starting_agent=agent, input=user_input)
print("Assistant:", result.final_output)
except InputGuardrailTripwireTriggered:
print("Assistant: Sorry, I can only help with supply-chain questions.")
except OutputGuardrailTripwireTriggered:
print("Assistant: Sorry, I can only help with supply-chain questions.")


def main():
asyncio.run(run_agent())


if __name__ == "__main__":
main()
File: requirements.txt
fastapi==0.115.13
uvicorn==0.34.3
pydantic==2.11.7
databricks-sdk==0.57.0
httpx==0.28.1
openai-agents==0.0.19
File: supply_chain_guardrails.py
"""
Output guardrail that blocks answers not related to supply-chain topics.
"""
from __future__ import annotations

from pydantic import BaseModel
from agents import Agent, Runner, GuardrailFunctionOutput
from agents import output_guardrail
from agents.run_context import RunContextWrapper

class SupplyChainCheckOutput(BaseModel):
reasoning: str
is_supply_chain: bool


guardrail_agent = Agent(
name="Supply-chain check",
instructions=(
"Check if the text is within the domain of supply-chain analytics and operations "
"Return JSON strictly matching the SupplyChainCheckOutput schema"
),
output_type=SupplyChainCheckOutput,
)


@output_guardrail
async def supply_chain_guardrail(
ctx: RunContextWrapper, agent: Agent, output
) -> GuardrailFunctionOutput:
"""Output guardrail that blocks non-supply-chain answers"""
text = output if isinstance(output, str) else getattr(output, "response", str(output))
result = await Runner.run(guardrail_agent, text, context=ctx.context)
return GuardrailFunctionOutput(
output_info=result.final_output,
tripwire_triggered=not result.final_output.is_supply_chain,
)
File: ui/.gitignore
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?