44from databricks .sdk import WorkspaceClient
55from databricks .sdk .service .dashboards import GenieMessage
66import pandas as pd
7- from typing import Dict
7+ from typing import Dict , List
88
99
1010# pages/ml_serving_invoke.py
1414 title = 'Genie' ,
1515 name = 'Genie' ,
1616 category = 'Business Intelligence' ,
17- icon = 'material-symbols:model-training '
17+ icon = 'material-symbols:chat '
1818)
1919
2020# Initialize WorkspaceClient with error handling
@@ -60,16 +60,22 @@ def dash_dataframe(df: pd.DataFrame) -> dash.dash_table.DataTable:
6060 return table
6161
6262
63- def display_message (message : Dict ):
64- if "content" in message :
65- dcc .Markdown (message ["content" ])
66- if "data" in message :
67- dash_dataframe (message ["data" ])
68- if "code" in message :
69- dcc .Markdown (f'''```sql { message ["code" ]} ```''' , className = "border rounded p-3" )
63+ def format_message_display (chat_history : List [Dict ]) -> List [Dict ]:
64+ chat_display = []
65+ for message in chat_history :
66+ display = []
67+ if "content" in message :
68+ display .append (dcc .Markdown (message ["content" ]))
69+ if "data" in message :
70+ display .append (message ["data" ])
71+ if "code" in message :
72+ display .append (dcc .Markdown (f'''```sql { message ["code" ]} ```''' , className = "border rounded p-3" ))
73+ chat_display .append (html .Div (display , className = f"chat-message { message ['role' ]} -message" ))
7074
75+ return chat_display
7176
72- def get_query_result (statement_id : str ) -> pd .DataFrame :
77+
78+ def get_query_result (statement_id : str ) -> dash .dash_table .DataTable :
7379 query = w .statement_execution .get_statement (statement_id )
7480 result = query .result .data_array
7581
@@ -79,20 +85,25 @@ def get_query_result(statement_id: str) -> pd.DataFrame:
7985 result .append (chunk .data_array )
8086 next_chunk = chunk .next_chunk_index
8187
82- return pd .DataFrame (result , columns = [i .name for i in query .manifest .schema .columns ])
88+ df = pd .DataFrame (result , columns = [i .name for i in query .manifest .schema .columns ])
89+
90+ return dash_dataframe (df )
8391
8492
85- def process_genie_response (response : GenieMessage ) :
93+ def process_genie_response (response : GenieMessage , chat_history : List [ Dict ]) -> List [ Dict ] :
8694 for i in response .attachments :
8795 if i .text :
8896 message = {"role" : "assistant" , "content" : i .text .content }
89- display_message (message )
97+ chat_history .append (message )
98+
9099 elif i .query :
91100 data = get_query_result (i .query .statement_id )
92101 message = {
93102 "role" : "assistant" , "content" : i .query .description , "data" : data , "code" : i .query .query
94103 }
95- display_message (message )
104+ chat_history .append (message )
105+
106+ return chat_history
96107
97108
98109def layout ():
@@ -106,6 +117,7 @@ def layout():
106117 href = "https://www.databricks.com/product/ai-bi" ,
107118 target = "_blank"
108119 ),
120+ " " ,
109121 html .A (
110122 "API" ,
111123 href = "https://docs.databricks.com/api/workspace/genie" ,
@@ -149,7 +161,8 @@ def layout():
149161 ], className = "mb-4" ),
150162
151163 # Chat history area
152- html .Div (id = "chat-history-genie" , className = "mt-4" ),
164+ html .Div (id = "chat-history" , className = "mt-4" ),
165+ dcc .Store (id = 'chat-history-store' ),
153166 dcc .Store (id = 'conversation-id' ),
154167
155168 # Status/error messages
@@ -233,48 +246,44 @@ def process_genie_response(response):
233246 ], fluid = True , className = "py-4" )
234247
235248@callback (
236- [Output (" chat-history-genie" , "children" ),
237- Output ("status-area-genie" , " children" )],
238- Input ("chat-button" , "n_clicks" ),
249+ [Output (' chat-history-store' , 'data' , allow_duplicate = True ),
250+ Output ('chat-history' , ' children' , allow_duplicate = True )],
251+ Input ("chat-button" , "n_clicks" ),
239252 [State ("genie-space-id-input" , "value" ),
240253 State ("conversation-id" , "value" ),
241- State ("question-input" , "value" )],
254+ State ("question-input" , "value" ),
255+ State ("chat-history-store" , "data" )],
242256 prevent_initial_call = True
243257)
244- def update_chat (n_clicks , genie_space_id , conversation_id , prompt ):
258+ def update_chat (n_clicks , genie_space_id , conversation_id , prompt , chat_history ):
245259 if not all ([genie_space_id , prompt ]):
246260 return dash .no_update , dbc .Alert (
247261 "Please fill in all fields" ,
248262 color = "warning"
249263 )
250264
265+ chat_history = chat_history or []
266+
251267 try :
252268 if conversation_id :
253269 conversation = w .genie .create_message_and_wait (
254270 genie_space_id , conversation_id , prompt
255271 )
256- process_genie_response (conversation )
272+ chat_history = process_genie_response (conversation , chat_history )
257273 else :
258274 conversation = w .genie .start_conversation_and_wait (genie_space_id , prompt )
259275 conversation_id = conversation .conversation_id
260- process_genie_response (conversation )
276+ chat_history = process_genie_response (conversation , chat_history )
261277
278+ chat_display = format_message_display (chat_history )
262279
263- return [
264- html .Div ([
265- dbc .Card (
266- dbc .CardBody ([
267- html .P (f"Q: { prompt } " ),
268- html .P ("A: Processing your question..." )
269- ])
270- )
271- ], className = "mb-3" )
272- ], None
280+ return chat_history , chat_display
281+
273282 except Exception as e :
274283 return dash .no_update , dbc .Alert (
275284 f"An error occurred: { str (e )} " ,
276285 color = "danger"
277286 )
278287
279288# Make layout available at module level
280- __all__ = ['layout' ]
289+ __all__ = ['layout' ]
0 commit comments