77 changes: 41 additions & 36 deletions src/webapp/databricks.py
@@ -4,12 +4,17 @@
 from pydantic import BaseModel
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service import catalog
-from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout
+from databricks.sdk.service.sql import (
+    Format,
+    ExecuteStatementRequestOnWaitTimeout,
+    Disposition,
+)
 from .config import databricks_vars, gcs_vars
 from .utilities import databricksify_inst_name, SchemaType
-from typing import List, Any
+from typing import List, Any, Dict
 from databricks.sdk.errors import DatabricksError
 
 
+# List of data medallion levels
 MEDALLION_LEVELS = ["silver", "gold", "bronze"]

@@ -196,62 +201,62 @@ def delete_inst(self, inst_name: str) -> None:
 
     def fetch_table_data(
         self,
-        catalog_name: Any,
-        schema_name: Any,
-        table_name: Any,
-        warehouse_id: Any,
+        catalog_name: str,
+        schema_name: str,
+        table_name: str,
+        warehouse_id: str,
         limit: int = 1000,
-    ) -> List[dict[str, Any]]:
+    ) -> List[Dict[str, Any]]:
         """
-        Runs a simple SELECT * FROM <catalog>.<schema>.<table> LIMIT <limit>
-        against the specified SQL warehouse, and returns a list of row-dicts.
+        Executes a SELECT * query on the specified table within the given catalog and schema,
+        using the provided SQL warehouse. Returns the result as a list of dictionaries.
         """
 
         try:
+            # Initialize the WorkspaceClient with default authentication
             client = WorkspaceClient(
                 host=databricks_vars["DATABRICKS_HOST_URL"],
                 google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
             )
         except Exception as e:
             raise ValueError(f"Failed to initialize WorkspaceClient: {e}")
 
-        # 2. Build SQL text
-        fully_qualified = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
-        sql_text = f"SELECT * FROM {fully_qualified} LIMIT {limit}"
+        # Construct the fully qualified table name
+        fully_qualified_table = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
+        sql_query = f"SELECT * FROM {fully_qualified_table} LIMIT {limit}"
 
-        # 3. Execute with INLINE+JSON_ARRAY, wait up to 30s, then CANCEL if not done
         try:
-            resp = client.statement_execution.execute_statement(
+            # Execute the SQL statement
+            response = client.statement_execution.execute_statement(
                 warehouse_id=warehouse_id,
-                statement=sql_text,
-                disposition="INLINE",  # INLINE disposition
-                format=Format.JSON_ARRAY,  # JSON_ARRAY format
-                wait_timeout="30s",  # up to 30 seconds
-                on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CANCEL,  # cancel if not done
+                statement=sql_query,
+                disposition=Disposition.INLINE,  # Use Enum member
+                format=Format.JSON_ARRAY,  # Use Enum member
+                wait_timeout="30s",  # Wait up to 30 seconds for execution
+                on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CANCEL,  # Use Enum member
             )
         except DatabricksError as e:
             raise ValueError(f"Databricks API call failed: {e}")
 
-        # 4. Check final state
-        state = resp.status.state
-        if state != "SUCCEEDED":
-            # If there’s an error object, include its message
-            err = resp.status.error
-            msg = (
-                err.message
-                if (err is not None and err.message)
+        # Check if the query execution was successful
+        if response.status.state != "SUCCEEDED":
+            error_message = (
+                response.status.error.message
+                if response.status.error
                 else "No additional error info."
             )
-            raise ValueError(f"Query did not succeed (state={state}): {msg}")
+            raise ValueError(
+                f"Query did not succeed (state={response.status.state}): {error_message}"
+            )
 
-        # 5. Ensure manifest and result are present
-        if resp.manifest is None or resp.manifest.schema is None:
+        # Validate the presence of the result and schema
+        if not response.manifest or not response.manifest.schema:
             raise ValueError("Query succeeded but schema manifest is missing.")
-        if resp.result is None or resp.result.data_array is None:
+        if not response.result or not response.result.data_array:
             raise ValueError("Query succeeded but result data is missing.")
 
-        # 6. Extract column names and rows
-        column_names = [col.name for col in resp.manifest.schema]
-        data_array = resp.result.data_array
+        # Extract column names and data rows
+        column_names = [column.name for column in response.manifest.schema]
+        data_rows = response.result.data_array
 
-        return [dict(zip(column_names, row)) for row in data_array]
+        # Combine column names with corresponding row values
+        return [dict(zip(column_names, row)) for row in data_rows]
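
For context on how the revised method is meant to be called, here is a minimal usage sketch; it is not part of the PR. The import path is assumed from the src/webapp layout above, and the catalog, schema, table, and warehouse values are placeholders.

# A minimal usage sketch, assuming the package import path matches the
# src/webapp layout above; all identifier values are placeholders.
from webapp.databricks import DatabricksControl

dbc = DatabricksControl()
rows = dbc.fetch_table_data(
    catalog_name="my_catalog",      # placeholder catalog
    schema_name="default",
    table_name="my_table",          # placeholder table
    warehouse_id="<warehouse-id>",  # placeholder SQL warehouse ID
    limit=10,
)
# Each element is one row as a dict keyed by column name.
for row in rows:
    print(row)

Passing Disposition.INLINE and Format.JSON_ARRAY as enum members rather than raw strings matches the databricks-sdk type hints, so a typo surfaces at the call site instead of as an API error.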
22 changes: 22 additions & 0 deletions src/webapp/routers/data.py
@@ -1023,6 +1023,28 @@ def get_upload_url(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
 
 
+# Test endpoint: fetch sample rows from a Databricks table
+@router.get("/inference/test")
+def test() -> Any:
+    """Fetches sample rows from a Databricks test table and returns them as row dicts."""
+    # Raise errors at this level; otherwise they get wrapped in a 200 response.
+
+    try:
+        dbc = DatabricksControl()
+        rows = dbc.fetch_table_data(
+            catalog_name="dev_sst_02",
+            schema_name="default",
+            table_name="test_dataset",
+            warehouse_id="28e1cbabfe6deb87",
+            limit=500,
+        )
+
+        return rows
+    except ValueError as ve:
+        # Return a 400 error with the specific message from ValueError
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
+
+
 # Get SHAP Values for Inference
 @router.get("/{inst_id}/inference/top-features/{run_id}", response_model=str)
 def get_top_features(
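
The new route's try/except illustrates the pattern its inline comment alludes to: raising HTTPException inside the handler is what produces the 400, whereas returning the error as a plain value would be serialized into a 200. A self-contained sketch, assuming a standalone FastAPI app, with a hypothetical fetch_rows helper standing in for DatabricksControl.fetch_table_data:

# A minimal sketch, assuming a standalone FastAPI app; fetch_rows is a
# hypothetical stand-in for DatabricksControl.fetch_table_data.
from typing import Any, Dict, List

from fastapi import FastAPI, HTTPException, status

app = FastAPI()


def fetch_rows() -> List[Dict[str, Any]]:
    # Hypothetical helper that fails the way the real Databricks call can.
    raise ValueError("Query did not succeed (state=FAILED)")


@app.get("/inference/test")
def test() -> List[Dict[str, Any]]:
    try:
        return fetch_rows()
    except ValueError as ve:
        # Re-raising as HTTPException yields an explicit 400 response;
        # without it, FastAPI would wrap whatever is returned in a 200.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))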