Skip to content

Commit 9385e86

Browse files
committed
Genie Dash implement happiest path
1 parent b7db1d1 commit 9385e86

File tree

1 file changed

+42
-33
lines changed

1 file changed

+42
-33
lines changed

dash/pages/genie_api.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from databricks.sdk import WorkspaceClient
55
from databricks.sdk.service.dashboards import GenieMessage
66
import pandas as pd
7-
from typing import Dict
7+
from typing import Dict, List
88

99

1010
# pages/ml_serving_invoke.py
@@ -14,7 +14,7 @@
1414
title='Genie',
1515
name='Genie',
1616
category='Business Intelligence',
17-
icon='material-symbols:model-training'
17+
icon='material-symbols:chat'
1818
)
1919

2020
# Initialize WorkspaceClient with error handling
@@ -60,16 +60,22 @@ def dash_dataframe(df: pd.DataFrame) -> dash.dash_table.DataTable:
6060
return table
6161

6262

63-
def display_message(message: Dict):
64-
if "content" in message:
65-
dcc.Markdown(message["content"])
66-
if "data" in message:
67-
dash_dataframe(message["data"])
68-
if "code" in message:
69-
dcc.Markdown(f'''```sql {message["code"]}```''', className="border rounded p-3")
63+
def format_message_display(chat_history: List[Dict]) -> List[Dict]:
64+
chat_display = []
65+
for message in chat_history:
66+
display = []
67+
if "content" in message:
68+
display.append(dcc.Markdown(message["content"]))
69+
if "data" in message:
70+
display.append(message["data"])
71+
if "code" in message:
72+
display.append(dcc.Markdown(f'''```sql {message["code"]}```''', className="border rounded p-3"))
73+
chat_display.append(html.Div(display, className=f"chat-message {message['role']}-message"))
7074

75+
return chat_display
7176

72-
def get_query_result(statement_id: str) -> pd.DataFrame:
77+
78+
def get_query_result(statement_id: str) -> dash.dash_table.DataTable:
7379
query = w.statement_execution.get_statement(statement_id)
7480
result = query.result.data_array
7581

@@ -79,20 +85,25 @@ def get_query_result(statement_id: str) -> pd.DataFrame:
7985
result.append(chunk.data_array)
8086
next_chunk = chunk.next_chunk_index
8187

82-
return pd.DataFrame(result, columns=[i.name for i in query.manifest.schema.columns])
88+
df = pd.DataFrame(result, columns=[i.name for i in query.manifest.schema.columns])
89+
90+
return dash_dataframe(df)
8391

8492

85-
def process_genie_response(response: GenieMessage):
93+
def process_genie_response(response: GenieMessage, chat_history: List[Dict]) -> List[Dict]:
8694
for i in response.attachments:
8795
if i.text:
8896
message = {"role": "assistant", "content": i.text.content}
89-
display_message(message)
97+
chat_history.append(message)
98+
9099
elif i.query:
91100
data = get_query_result(i.query.statement_id)
92101
message = {
93102
"role": "assistant", "content": i.query.description, "data": data, "code": i.query.query
94103
}
95-
display_message(message)
104+
chat_history.append(message)
105+
106+
return chat_history
96107

97108

98109
def layout():
@@ -106,6 +117,7 @@ def layout():
106117
href="https://www.databricks.com/product/ai-bi",
107118
target="_blank"
108119
),
120+
" ",
109121
html.A(
110122
"API",
111123
href="https://docs.databricks.com/api/workspace/genie",
@@ -149,7 +161,8 @@ def layout():
149161
], className="mb-4"),
150162

151163
# Chat history area
152-
html.Div(id="chat-history-genie", className="mt-4"),
164+
html.Div(id="chat-history", className="mt-4"),
165+
dcc.Store(id='chat-history-store'),
153166
dcc.Store(id='conversation-id'),
154167

155168
# Status/error messages
@@ -233,48 +246,44 @@ def process_genie_response(response):
233246
], fluid=True, className="py-4")
234247

235248
@callback(
236-
[Output("chat-history-genie", "children"),
237-
Output("status-area-genie", "children")],
238-
Input("chat-button", "n_clicks"),
249+
[Output('chat-history-store', 'data', allow_duplicate=True),
250+
Output('chat-history', 'children', allow_duplicate=True)],
251+
Input("chat-button", "n_clicks"),
239252
[State("genie-space-id-input", "value"),
240253
State("conversation-id", "value"),
241-
State("question-input", "value")],
254+
State("question-input", "value"),
255+
State("chat-history-store", "data")],
242256
prevent_initial_call=True
243257
)
244-
def update_chat(n_clicks, genie_space_id, conversation_id, prompt):
258+
def update_chat(n_clicks, genie_space_id, conversation_id, prompt, chat_history):
245259
if not all([genie_space_id, prompt]):
246260
return dash.no_update, dbc.Alert(
247261
"Please fill in all fields",
248262
color="warning"
249263
)
250264

265+
chat_history = chat_history or []
266+
251267
try:
252268
if conversation_id:
253269
conversation = w.genie.create_message_and_wait(
254270
genie_space_id, conversation_id, prompt
255271
)
256-
process_genie_response(conversation)
272+
chat_history = process_genie_response(conversation, chat_history)
257273
else:
258274
conversation = w.genie.start_conversation_and_wait(genie_space_id, prompt)
259275
conversation_id = conversation.conversation_id
260-
process_genie_response(conversation)
276+
chat_history = process_genie_response(conversation, chat_history)
261277

278+
chat_display = format_message_display(chat_history)
262279

263-
return [
264-
html.Div([
265-
dbc.Card(
266-
dbc.CardBody([
267-
html.P(f"Q: {prompt}"),
268-
html.P("A: Processing your question...")
269-
])
270-
)
271-
], className="mb-3")
272-
], None
280+
return chat_history, chat_display
281+
273282
except Exception as e:
274283
return dash.no_update, dbc.Alert(
275284
f"An error occurred: {str(e)}",
276285
color="danger"
277286
)
278287

279288
# Make layout available at module level
280-
__all__ = ['layout']
289+
__all__ = ['layout']

0 commit comments

Comments
 (0)