
Commit 0394278

Merge pull request #56 from antonbricks/ttl-sql-injection-genie
Implement TTL in SQL connections, mitigate SQL Injections, & expand Genie - Streamlit, Dash
2 parents 680bc09 + 9cb95bc commit 0394278

File tree: 17 files changed (+174, −41 lines)

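The change repeated across these files follows two patterns: cell values are bound as named query parameters instead of being interpolated into the SQL text, and cached SQL connections get a 300-second TTL. A minimal sketch of both patterns, assuming the databricks-sql-connector's native named-parameter binding (`:name` markers plus a dict) used in the diffs below; the warehouse path, table name, and column names are placeholders, and the credentials_provider argument is one common way to authenticate:

import streamlit as st
from databricks import sql
from databricks.sdk.core import Config

cfg = Config()  # reads DATABRICKS_HOST and friends from the environment


@st.cache_resource(ttl=300, show_spinner=True)  # cached connection expires after 300 seconds
def get_connection(http_path: str):
    return sql.connect(
        server_hostname=cfg.host,
        http_path=http_path,
        credentials_provider=lambda: cfg.authenticate,  # assumed auth pattern
    )


def insert_row(conn, table_name: str, name: str, amount: float):
    # Identifiers can't be bound as parameters, so table_name must come from
    # trusted configuration; the values travel as bound parameters and never
    # become part of the SQL text.
    with conn.cursor() as cursor:
        cursor.execute(
            f"INSERT INTO {table_name} (name, amount) VALUES (:name, :amount)",
            {"name": name, "amount": amount},
        )


conn = get_connection("/sql/1.0/warehouses/xxxxxx")     # placeholder warehouse path
insert_row(conn, "catalog.schema.table", "alpha", 1.5)  # placeholder table and values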

dash/pages/tables_edit.py
Lines changed: 42 additions & 6 deletions

@@ -33,9 +33,27 @@ def read_table(table_name: str, conn) -> pd.DataFrame:

def insert_overwrite_table(table_name: str, df: pd.DataFrame, conn):
    with conn.cursor() as cursor:
-        rows = list(df.itertuples(index=False))
-        values = ",".join([f"({','.join(map(repr, row))})" for row in rows])
-        cursor.execute(f"INSERT OVERWRITE {table_name} VALUES {values}")
+        rows = list(df.itertuples(index=False, name=None))
+        if not rows:
+            return
+
+        cols = list(df.columns)
+        num_cols = len(cols)
+        params = {}
+        values_sql_parts = []
+        p = 0
+        for row in rows:
+            ph = []
+            for v in row:
+                key = f"p{p}"
+                ph.append(f":{key}")
+                params[key] = v
+                p += 1
+            values_sql_parts.append("(" + ",".join(ph) + ")")
+
+        values_sql = ",".join(values_sql_parts)
+        col_list_sql = ",".join(cols)
+        cursor.execute(f"INSERT OVERWRITE {table_name} ({col_list_sql}) VALUES {values_sql}", params)

def layout():
    return dbc.Container([
@@ -110,9 +128,27 @@ def read_table(table_name: str, conn) -> pd.DataFrame:

def insert_overwrite_table(table_name: str, df: pd.DataFrame, conn):
    with conn.cursor() as cursor:
-        rows = list(df.itertuples(index=False))
-        values = ",".join([f"({','.join(map(repr, row))})" for row in rows])
-        cursor.execute(f"INSERT OVERWRITE {table_name} VALUES {values}")
+        rows = list(df.itertuples(index=False, name=None))
+        if not rows:
+            return
+
+        cols = list(df.columns)
+        num_cols = len(cols)
+        params = {}
+        values_sql_parts = []
+        p = 0
+        for row in rows:
+            ph = []
+            for v in row:
+                key = f"p{p}"
+                ph.append(f":{key}")
+                params[key] = v
+                p += 1
+            values_sql_parts.append("(" + ",".join(ph) + ")")
+
+        values_sql = ",".join(values_sql_parts)
+        col_list_sql = ",".join(cols)
+        cursor.execute(f"INSERT OVERWRITE {table_name} ({col_list_sql}) VALUES {values_sql}", params)

http_path_input = "/sql/1.0/warehouses/xxxxxx"
table_name = "catalog.schema.table"
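A hypothetical call into the parameterized helper above; the warehouse path and table name are the placeholders already used in the snippet, the DataFrame contents are illustrative, and the connection arguments are an assumed pattern:

import pandas as pd
from databricks import sql
from databricks.sdk.core import Config

cfg = Config()
conn = sql.connect(
    server_hostname=cfg.host,
    http_path="/sql/1.0/warehouses/xxxxxx",         # placeholder warehouse path
    credentials_provider=lambda: cfg.authenticate,  # assumed auth pattern
)

edited = pd.DataFrame({"id": [1, 2], "name": ["alpha", "beta"]})  # illustrative data
# Each cell is bound as a named parameter (:p0, :p1, ...), so edited values
# cannot change the shape of the INSERT OVERWRITE statement itself.
insert_overwrite_table("catalog.schema.table", edited, conn)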

docs/docs/dash/bi/genie_api.mdx
Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@ sidebar_position: 2

# Chat with a Genie Space

-This app uses the [AI/BI Genie](https://www.databricks.com/product/ai-bi) [Conversations API](https://docs.databricks.com/api/workspace/genie) to let users ask questions about your data for instant insights.
+This app uses the [AI/BI Genie](https://www.databricks.com/product/ai-bi) [Conversations API](https://docs.databricks.com/api/workspace/genie) to let users ask questions about your data for instant insights (answers and table-like output). Visualizations aren't yet supported in the API.

## Code snippet

@@ -33,7 +33,7 @@ def process_genie_response(response):
            print(f"A: {i.query.description}")
            print(f"Data: {data}")
            print(f"Generated code: {i.query.query}")
-
+

# Configuration
w = WorkspaceClient()
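For context, the Conversations API flow this page describes can be driven from the Databricks SDK roughly as follows. This is a hedged sketch rather than the cookbook implementation; the space ID and question are placeholders:

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
genie_space_id = "your-genie-space-id"  # placeholder

# Start a conversation and wait for Genie's reply; the reply carries text
# and/or query attachments (no visualizations, per the note above).
message = w.genie.start_conversation_and_wait(genie_space_id, "What were sales last month?")
for attachment in message.attachments or []:
    if attachment.text:
        print(attachment.text.content)
    if attachment.query:
        print(attachment.query.description)
        print(attachment.query.query)  # the SQL Genie generated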

docs/docs/dash/tables/tables_edit.mdx
Lines changed: 21 additions & 3 deletions

@@ -31,9 +31,27 @@ def read_table(table_name: str, conn) -> pd.DataFrame:

def insert_overwrite_table(table_name: str, df: pd.DataFrame, conn):
    with conn.cursor() as cursor:
-        rows = list(df.itertuples(index=False))
-        values = ",".join([f"({','.join(map(repr, row))})" for row in rows])
-        cursor.execute(f"INSERT OVERWRITE {table_name} VALUES {values}")
+        rows = list(df.itertuples(index=False, name=None))
+        if not rows:
+            return
+
+        cols = list(df.columns)
+        num_cols = len(cols)
+        params = {}
+        values_sql_parts = []
+        p = 0
+        for row in rows:
+            ph = []
+            for v in row:
+                key = f"p{p}"
+                ph.append(f":{key}")
+                params[key] = v
+                p += 1
+            values_sql_parts.append("(" + ",".join(ph) + ")")
+
+        values_sql = ",".join(values_sql_parts)
+        col_list_sql = ",".join(cols)
+        cursor.execute(f"INSERT OVERWRITE {table_name} ({col_list_sql}) VALUES {values_sql}", params)

http_path_input = "/sql/1.0/warehouses/xxxxxx"
table_name = "catalog.schema.table"

docs/docs/streamlit/authentication/users_obo.mdx
Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ def get_user_token():
    user_token = headers["X-Forwarded-Access-Token"]
    return user_token

-@st.cache_resource
+@st.cache_resource(ttl=300, show_spinner=True)
def connect_with_obo(http_path, user_token):
    return sql.connect(
        server_hostname=cfg.host,
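Worth noting for the snippet above: st.cache_resource keys its cache on the function arguments, so each distinct (http_path, user_token) pair gets its own cached connection, and each entry now expires 300 seconds after it is created. A small illustrative call using the helpers from this page; the warehouse path is a placeholder:

http_path = "/sql/1.0/warehouses/xxxxxx"        # placeholder warehouse path
user_token = get_user_token()                   # per-user token from the forwarded header
conn = connect_with_obo(http_path, user_token)  # cached per user token for up to 300 seconds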

docs/docs/streamlit/bi/genie_api.mdx
Lines changed: 12 additions & 1 deletion

@@ -4,7 +4,7 @@ sidebar_position: 1

# Chat with a Genie Space

-This app uses the [AI/BI Genie](https://www.databricks.com/product/ai-bi) [Conversations API](https://docs.databricks.com/api/workspace/genie) to let users ask questions about your data for instant insights.
+This app uses the [AI/BI Genie](https://www.databricks.com/product/ai-bi) [Conversations API](https://docs.databricks.com/api/workspace/genie) to let users ask questions about your data for instant insights (answers and table-like output). You are also able to collect their feedback on the responses. Visualizations aren't yet supported in the API.

## Code snippet

@@ -13,6 +13,7 @@ Refer to the Streamlit Cookbook Genie source code for the full implementation.
```python title="app.py"
import streamlit as st
from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.dashboards import GenieFeedbackRating
import pandas as pd

w = WorkspaceClient()
@@ -39,6 +40,15 @@ def get_query_result(statement_id):
    )


+def collect_feedback(message_id: str):
+    rating = st.feedback("thumbs", key=f"feedback_{message_id}")
+    mapping = {1: GenieFeedbackRating.POSITIVE, 0: GenieFeedbackRating.NEGATIVE}
+    if rating and message["message_id"]:
+        w.genie.send_message_feedback(
+            genie_space_id, st.session_state.conversation_id, message["message_id"], mapping[rating]
+        )
+
+
def process_genie_response(response):
    for i in response.attachments:
        if i.text:
@@ -50,6 +60,7 @@ def process_genie_response(response):
                "role": "assistant", "content": i.query.description, "data": data, "code": i.query.query
            }
            display_message(message)
+            collect_feedback(response.message_id)


if prompt := st.chat_input("Ask your question..."):
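One caveat on the added collect_feedback helper: st.feedback("thumbs") returns None until the user clicks and then 0 (thumbs down) or 1 (thumbs up), so the truthiness check on rating drops thumbs-down feedback, and the body reads message["message_id"] rather than its message_id parameter. A hedged variant under the same assumptions (same SDK call, same session state):

def collect_feedback(message_id: str):
    rating = st.feedback("thumbs", key=f"feedback_{message_id}")
    mapping = {1: GenieFeedbackRating.POSITIVE, 0: GenieFeedbackRating.NEGATIVE}
    if rating is not None and message_id:
        # Forward the rating for this specific message to the Genie space.
        w.genie.send_message_feedback(
            genie_space_id, st.session_state.conversation_id, message_id, mapping[rating]
        )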

docs/docs/streamlit/tables/tables_edit.mdx
Lines changed: 24 additions & 7 deletions

@@ -18,7 +18,7 @@ from databricks.sdk.core import Config
cfg = Config()  # Set the DATABRICKS_HOST environment variable when running locally


-@st.cache_resource(ttl="1h")
+@st.cache_resource(ttl=300, show_spinner=True)
def get_connection(http_path):
    return sql.connect(
        server_hostname=cfg.host,
@@ -36,13 +36,30 @@ def read_table(table_name: str, conn) -> pd.DataFrame:
def insert_overwrite_table(table_name: str, df: pd.DataFrame, conn):
    progress = st.empty()
    with conn.cursor() as cursor:
-        rows = list(df.itertuples(index=False))
-        values = ",".join([f"({','.join(map(repr, row))})" for row in rows])
+        rows = list(df.itertuples(index=False, name=None))
+        if not rows:
+            return
+
+        cols = list(df.columns)
+        num_cols = len(cols)
+        params = {}
+        values_sql_parts = []
+        p = 0
+        for row in rows:
+            ph = []
+            for v in row:
+                key = f"p{p}"
+                ph.append(f":{key}")
+                params[key] = v
+                p += 1
+            values_sql_parts.append("(" + ",".join(ph) + ")")
+
+        values_sql = ",".join(values_sql_parts)
+        col_list_sql = ",".join(cols)
+
        with progress:
            st.info("Calling Databricks SQL...")
-            cursor.execute(f"INSERT OVERWRITE {table_name} VALUES {values}")
-            progress.empty()
-            st.success("Changes saved")
+            cursor.execute(f"INSERT OVERWRITE {table_name} ({col_list_sql}) VALUES {values_sql}", params)


http_path_input = st.text_input(
@@ -69,7 +86,7 @@ else:

:::info

-This sample uses Streamlit's [st.cache_resource](https://docs.streamlit.io/develop/concepts/architecture/caching#stcache_resource) with a 1-hour TTL (time-to-live) to cache the database connection across users, sessions, and reruns. The cached connection will automatically expire after 1 hour, ensuring connections don't become stale. Use Streamlit's caching decorators and TTL parameter to implement a caching strategy that works for your use case.
+This sample uses Streamlit's [st.cache_resource](https://docs.streamlit.io/develop/concepts/architecture/caching#stcache_resource) with a 300-second TTL (time-to-live) to cache the database connection across users, sessions, and reruns. The cached connection will automatically expire after 5 minutes, ensuring connections don't become stale. Use Streamlit's caching decorators and TTL parameter to implement a caching strategy that works for your use case.

:::
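Since the connection is now cached for 300 seconds, a broken or stale connection can also be discarded before the TTL elapses by clearing the cached resource; a small illustrative addition, not part of the sample:

get_connection.clear()                   # drop all cached connections
conn = get_connection(http_path_input)   # the next call reconnects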

docs/docs/streamlit/tables/tables_read.mdx
Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@ from databricks.sdk.core import Config
cfg = Config()  # Set the DATABRICKS_HOST environment variable when running locally


-@st.cache_resource(ttl="1h")  # connection is cached
+@st.cache_resource(ttl=300, show_spinner=True)  # connection is cached
def get_connection(http_path):
    return sql.connect(
        server_hostname=cfg.host,
@@ -47,7 +47,7 @@ if http_path_input and table_name:

:::info

-This sample uses Streamlit's [st.cache_resource](https://docs.streamlit.io/develop/concepts/architecture/caching#stcache_resource) with a 1-hour TTL (time-to-live) to cache the database connection across users, sessions, and reruns. The cached connection will automatically expire after 1 hour, ensuring connections don't become stale. Use Streamlit's caching decorators and TTL parameter to implement a caching strategy that works for your use case.
+This sample uses Streamlit's [st.cache_resource](https://docs.streamlit.io/develop/concepts/architecture/caching#stcache_resource) with a 300-second TTL (time-to-live) to cache the database connection across users, sessions, and reruns. The cached connection will automatically expire after 5 minutes, ensuring connections don't become stale. Use Streamlit's caching decorators and TTL parameter to implement a caching strategy that works for your use case.

:::

docs/docs/streamlit/visualizations/visualizations_charts.mdx
Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ warehouses = w.warehouses.list()
warehouse_paths = {wh.name: wh.odbc_params.path for wh in warehouses}

# Connect to SQL warehouse
-@st.cache_resource
+@st.cache_resource(ttl=300, show_spinner=True)
def get_connection(http_path):
    return sql.connect(
        server_hostname=cfg.host,

docs/docs/streamlit/visualizations/visualizations_map.mdx
Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ warehouses = w.warehouses.list()
warehouse_paths = {wh.name: wh.odbc_params.path for wh in warehouses}

# Connect to SQL warehouse
+@st.cache_resource(ttl=300, show_spinner=True)
def get_connection(http_path):
    return sql.connect(
        server_hostname=cfg.host,

streamlit/views/external_connections.py
Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
import json


-@st.cache_resource
+@st.cache_resource(ttl=300, show_spinner=True)
def get_client_obo() -> WorkspaceClient:
    user_token = st.context.headers.get("x-forwarded-access-token")
    if not user_token:
