Commit 6cb33e9

Fixes
1 parent ad53710 commit 6cb33e9

2 files changed: 35 additions, 99 deletions


streamlit/views/tables_edit.py

Lines changed: 18 additions & 39 deletions
@@ -15,18 +15,15 @@
 
 cfg = Config()
 
-# Initialize the client
 w = WorkspaceClient()
 
-# List SQL Warehouses
 warehouses = w.warehouses.list()
 
-# Create a dictionary to map warehouse names to their paths
 warehouse_paths = {wh.name: wh.odbc_params.path for wh in warehouses}
 
-# List catalogs
 catalogs = w.catalogs.list()
 
+
 @st.cache_resource
 def get_connection(http_path):
     return sql.connect(
@@ -35,36 +32,24 @@ def get_connection(http_path):
         credentials_provider=lambda: cfg.authenticate,
     )
 
+
 def read_table(table_name, conn):
     with conn.cursor() as cursor:
         query = f"SELECT * FROM {table_name}"
         cursor.execute(query)
         return cursor.fetchall_arrow().to_pandas()
 
+
 def get_schema_names(catalog_name):
     schemas = w.schemas.list(catalog_name=catalog_name)
     return [schema.name for schema in schemas]
 
+
 def get_table_names(catalog_name, schema_name):
     tables = w.tables.list(catalog_name=catalog_name, schema_name=schema_name)
     return [table.name for table in tables]
 
 
-@st.cache_resource
-def get_connection(http_path):
-    return sql.connect(
-        server_hostname=cfg.host,
-        http_path=http_path,
-        credentials_provider=lambda: cfg.authenticate,
-    )
-
-
-def read_table(table_name: str, conn) -> pd.DataFrame:
-    with conn.cursor() as cursor:
-        cursor.execute(f"SELECT * FROM {table_name}")
-        return cursor.fetchall_arrow().to_pandas()
-
-
 def insert_overwrite_table(table_name: str, df: pd.DataFrame, conn):
     progress = st.empty()
     with conn.cursor() as cursor:
@@ -81,35 +66,30 @@ def insert_overwrite_table(table_name: str, df: pd.DataFrame, conn):
 
 with tab_a:
     http_path_input = st.selectbox(
-        "Select your Databricks SQL Warehouse:", [""] + list(warehouse_paths.keys())
+        "Select a SQL warehouse:", [""] + list(warehouse_paths.keys())
     )
 
     catalog_name = st.selectbox(
-        "Select your Catalog:", [""] + [catalog.name for catalog in catalogs]
+        "Select a catalog:", [""] + [catalog.name for catalog in catalogs]
     )
-    #Message to prompt user to select warehouse and catalog
-    if http_path_input == "" or catalog_name == "":
-        st.warning("Select Warehouse and Catalog")
-
+
     if catalog_name and catalog_name != "":
         schema_names = get_schema_names(catalog_name)
-        schema_name = st.selectbox(
-            "Select your Schema:", [""] + schema_names
-        )
-        if schema_name == "":
-            st.warning("Select Schema")
+        schema_name = st.selectbox("Select a schema:", [""] + schema_names)
 
     if catalog_name and catalog_name != "" and schema_name and schema_name != "":
         table_names = get_table_names(catalog_name, schema_name)
-        table_name = st.selectbox(
-            "Select your Table:", [""] + table_names
-        )
-        if table_name == "":
-            st.warning("Select Table")
-
+        table_name = st.selectbox("Select a table:", [""] + table_names)
+
     in_table_name = f"{catalog_name}.{schema_name}.{table_name}"
 
-    if http_path_input and table_name and catalog_name and schema_name and table_name != "":
+    if (
+        http_path_input
+        and table_name
+        and catalog_name
+        and schema_name
+        and table_name != ""
+    ):
         http_path = warehouse_paths[http_path_input]
         conn = get_connection(http_path)
         original_df = read_table(in_table_name, conn)
@@ -119,8 +99,6 @@ def insert_overwrite_table(table_name: str, df: pd.DataFrame, conn):
         if not df_diff.empty:
             if st.button("Save changes"):
                 insert_overwrite_table(in_table_name, edited_df, conn)
-    # else:
-    #     st.warning("Provide both the warehouse path and a table name to load data.")
 
 
 with tab_b:
@@ -131,6 +109,7 @@ def insert_overwrite_table(table_name: str, df: pd.DataFrame, conn):
     from databricks import sql
     from databricks.sdk.core import Config
 
+
     cfg = Config()  # Set the DATABRICKS_HOST environment variable when running locally
 
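The hunks above reference original_df, edited_df, df_diff, and insert_overwrite_table without showing how they connect; the editor widget and the body of insert_overwrite_table fall outside this diff. As a rough, hypothetical sketch only (it assumes st.data_editor for the editing UI, a concat/drop_duplicates check for detecting edits, and a VALUES-based INSERT OVERWRITE, none of which is confirmed by this commit), the flow could look like this:

# Hypothetical sketch of the edit-and-save flow around the pieces shown above.
# st.data_editor, the concat-based change check, and the VALUES-based
# INSERT OVERWRITE are assumptions; they are not confirmed by this commit.
import pandas as pd
import streamlit as st
from databricks import sql
from databricks.sdk.core import Config

cfg = Config()  # host and auth come from the environment


@st.cache_resource
def get_connection(http_path):
    # Cached so Streamlit reruns reuse one connection per warehouse path
    return sql.connect(
        server_hostname=cfg.host,
        http_path=http_path,
        credentials_provider=lambda: cfg.authenticate,
    )


def read_table(table_name, conn):
    with conn.cursor() as cursor:
        cursor.execute(f"SELECT * FROM {table_name}")
        return cursor.fetchall_arrow().to_pandas()


def insert_overwrite_table(table_name, df, conn):
    # Assumed implementation: rewrite the whole table from the edited rows.
    # Naive literal quoting, for illustration only.
    rows = ", ".join(
        "(" + ", ".join("NULL" if pd.isna(v) else repr(v) for v in row) + ")"
        for row in df.itertuples(index=False)
    )
    with conn.cursor() as cursor:
        cursor.execute(f"INSERT OVERWRITE {table_name} VALUES {rows}")


http_path = st.text_input("HTTP path:", placeholder="/sql/1.0/warehouses/xxxxxx")
table = st.text_input("Table:", placeholder="catalog.schema.table")

if http_path and table:
    conn = get_connection(http_path)
    original_df = read_table(table, conn)
    edited_df = st.data_editor(original_df)  # assumed editing widget
    # Rows that differ between the original and edited frames
    df_diff = pd.concat([original_df, edited_df]).drop_duplicates(keep=False)
    if not df_diff.empty and st.button("Save changes"):
        insert_overwrite_table(table, edited_df, conn)

In the actual file, insert_overwrite_table also drives a st.empty() progress placeholder, which is omitted from this sketch.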

streamlit/views/tables_read.py

Lines changed: 17 additions & 60 deletions
@@ -11,18 +11,15 @@
 
 cfg = Config()
 
-# Initialize the client
 w = WorkspaceClient()
 
-# List SQL Warehouses
 warehouses = w.warehouses.list()
 
-# Create a dictionary to map warehouse names to their paths
 warehouse_paths = {wh.name: wh.odbc_params.path for wh in warehouses}
 
-# List catalogs
 catalogs = w.catalogs.list()
 
+
 @st.cache_resource
 def get_connection(http_path):
     return sql.connect(
@@ -31,58 +28,50 @@ def get_connection(http_path):
         credentials_provider=lambda: cfg.authenticate,
     )
 
+
 def read_table(table_name, conn):
     with conn.cursor() as cursor:
         query = f"SELECT * FROM {table_name}"
         cursor.execute(query)
         return cursor.fetchall_arrow().to_pandas()
 
+
 def get_schema_names(catalog_name):
     schemas = w.schemas.list(catalog_name=catalog_name)
     return [schema.name for schema in schemas]
 
+
 def get_table_names(catalog_name, schema_name):
     tables = w.tables.list(catalog_name=catalog_name, schema_name=schema_name)
     return [table.name for table in tables]
 
+
 tab_a, tab_b, tab_c = st.tabs(["**Try it**", "**Code snippet**", "**Requirements**"])
 
 with tab_a:
     http_path_input = st.selectbox(
-        "Select your Databricks SQL Warehouse:", [""] + list(warehouse_paths.keys())
+        "Select a SQL warehouse:", [""] + list(warehouse_paths.keys())
     )
 
     catalog_name = st.selectbox(
-        "Select your Catalog:", [""] + [catalog.name for catalog in catalogs]
+        "Select a catalog:", [""] + [catalog.name for catalog in catalogs]
     )
 
     if http_path_input == "" or catalog_name == "":
         st.warning("Select Warehouse and Catalog")
 
     if catalog_name and catalog_name != "":
         schema_names = get_schema_names(catalog_name)
-        schema_name = st.selectbox(
-            "Select your Schema:", [""] + schema_names
-        )
-        if schema_name == "":
-            st.warning("Select Schema")
+        schema_name = st.selectbox("Select a schema:", [""] + schema_names)
 
     if catalog_name and catalog_name != "" and schema_name and schema_name != "":
         table_names = get_table_names(catalog_name, schema_name)
-        table_name = st.selectbox(
-            "Select your Table:", [""] + table_names
-        )
-
-        if table_name == "":
-            st.warning("Select Table")
+        table_name = st.selectbox("Select a table:", [""] + table_names)
 
     if http_path_input and table_name and table_name != "":
         http_path = warehouse_paths[http_path_input]
         conn = get_connection(http_path)
-        #info_placeholder = st.empty()
-        st.info(f"Running Select on {catalog_name}.{schema_name}.{table_name}")
         df = read_table(f"{catalog_name}.{schema_name}.{table_name}", conn)
-        #info_placeholder.empty() # Clear the info message
         st.dataframe(df)
 
 
@@ -92,23 +81,12 @@ def get_table_names(catalog_name, schema_name):
     import streamlit as st
     from databricks import sql
     from databricks.sdk.core import Config
-    from databricks.sdk import WorkspaceClient
-
-    cfg = Config()
-
-    # Initialize the client
-    w = WorkspaceClient()
 
-    # List SQL Warehouses
-    warehouses = w.warehouses.list()
 
-    # Create a dictionary to map warehouse names to their paths
-    warehouse_paths = {wh.name: wh.odbc_params.http_path for wh in warehouses}
+    cfg = Config()  # Set the DATABRICKS_HOST environment variable when running locally
 
-    # List catalogs
-    catalogs = w.catalogs.list()
 
-    @st.cache_resource
+    @st.cache_resource  # connection is cached
     def get_connection(http_path):
         return sql.connect(
             server_hostname=cfg.host,
@@ -122,37 +100,16 @@ def read_table(table_name, conn):
             cursor.execute(query)
             return cursor.fetchall_arrow().to_pandas()
 
-    def get_schema_names(catalog_name):
-        schemas = w.schemas.list(catalog_name=catalog_name)
-        return [schema.name for schema in schemas]
-
-    def get_table_names(catalog_name, schema_name):
-        tables = w.tables.list(catalog_name=catalog_name, schema_name=schema_name)
-        return [table.name for table in tables]
-
-    http_path_input = st.selectbox(
-        "Select your Databricks SQL Warehouse", list(warehouse_paths.keys()), placeholder="select warehouse"
+    http_path_input = st.text_input(
+        "Enter your Databricks HTTP Path:", placeholder="/sql/1.0/warehouses/xxxxxx"
     )
 
-    catalog_name = st.selectbox(
-        "Select your Catalog:", [catalog.name for catalog in catalogs]
+    table_name = st.text_input(
+        "Specify a Unity Catalog table name:", placeholder="catalog.schema.table"
     )
 
-    if catalog_name:
-        schema_names = get_schema_names(catalog_name)
-        schema_name = st.selectbox(
-            "Select your Schema:", schema_names
-        )
-
-    if catalog_name and schema_name:
-        table_names = get_table_names(catalog_name, schema_name)
-        table_name = st.selectbox(
-            "Select your Table:", table_names
-        )
-
     if http_path_input and table_name:
-        http_path = warehouse_paths[http_path_input]
-        conn = get_connection(http_path)
+        conn = get_connection(http_path_input)
         df = read_table(table_name, conn)
         st.dataframe(df)
     """
@@ -179,4 +136,4 @@ def get_table_names(catalog_name, schema_name):
     * [Databricks SDK](https://pypi.org/project/databricks-sdk/) - `databricks-sdk`
     * [Databricks SQL Connector](https://pypi.org/project/databricks-sql-connector/) - `databricks-sql-connector`
     * [Streamlit](https://pypi.org/project/streamlit/) - `streamlit`
-    """)
+    """)
