diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py
index 696f20c4..735e4596 100644
--- a/src/webapp/databricks.py
+++ b/src/webapp/databricks.py
@@ -4,12 +4,18 @@
 from pydantic import BaseModel
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service import catalog
-from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout
+from databricks.sdk.service.sql import (
+    Format,
+    ExecuteStatementRequestOnWaitTimeout,
+    Disposition,
+    StatementState,
+)
 from .config import databricks_vars, gcs_vars
 from .utilities import databricksify_inst_name, SchemaType
-from typing import List, Any
+from typing import List, Any, Dict
 from databricks.sdk.errors import DatabricksError
+
 
 # List of data medallion levels
 MEDALLION_LEVELS = ["silver", "gold", "bronze"]
 
@@ -196,18 +202,19 @@ def delete_inst(self, inst_name: str) -> None:
 
     def fetch_table_data(
         self,
-        catalog_name: Any,
-        schema_name: Any,
-        table_name: Any,
-        warehouse_id: Any,
+        catalog_name: str,
+        schema_name: str,
+        table_name: str,
+        warehouse_id: str,
         limit: int = 1000,
-    ) -> List[dict[str, Any]]:
+    ) -> List[Dict[str, Any]]:
         """
-        Runs a simple SELECT * FROM .. LIMIT
-        against the specified SQL warehouse, and returns a list of row‐dicts.
+        Executes a SELECT * query on the specified table within the given catalog and schema,
+        using the provided SQL warehouse. Returns the result as a list of dictionaries.
         """
-        try:
+        # Initialize the WorkspaceClient with service-account authentication
+        try:
             client = WorkspaceClient(
                 host=databricks_vars["DATABRICKS_HOST_URL"],
                 google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
             )
@@ -215,43 +222,43 @@ def fetch_table_data(
         except Exception as e:
             raise ValueError(f"Failed to initialize WorkspaceClient: {e}")
 
-        # 2. Build SQL text
-        fully_qualified = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
-        sql_text = f"SELECT * FROM {fully_qualified} LIMIT {limit}"
+        # Construct the fully qualified table name and the query text
+        fully_qualified_table = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
+        sql_query = f"SELECT * FROM {fully_qualified_table} LIMIT {limit}"
 
-        # 3. Execute with INLINE+JSON_ARRAY, wait up to 30s, then CANCEL if not done
         try:
-            resp = client.statement_execution.execute_statement(
+            # Execute the SQL statement against the given warehouse
+            response = client.statement_execution.execute_statement(
                 warehouse_id=warehouse_id,
-                statement=sql_text,
-                disposition="INLINE",  # INLINE disposition
-                format=Format.JSON_ARRAY,  # JSON_ARRAY format
-                wait_timeout="30s",  # up to 30 seconds
-                on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CANCEL,  # cancel if not done
+                statement=sql_query,
+                disposition=Disposition.INLINE,  # Return results inline in the response
+                format=Format.JSON_ARRAY,  # Encode rows as JSON arrays
+                wait_timeout="30s",  # Wait up to 30 seconds for execution
+                on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CANCEL,  # Cancel if still running after the timeout
             )
         except DatabricksError as e:
             raise ValueError(f"Databricks API call failed: {e}")
 
-        # 4. Check final state
-        state = resp.status.state
-        if state != "SUCCEEDED":
-            # If there’s an error object, include its message
-            err = resp.status.error
-            msg = (
-                err.message
-                if (err is not None and err.message)
+        # Check whether the query succeeded (compare against the enum, not a raw string)
+        if response.status.state != StatementState.SUCCEEDED:
+            error_message = (
+                response.status.error.message
+                if response.status.error and response.status.error.message
                 else "No additional error info."
             )
-            raise ValueError(f"Query did not succeed (state={state}): {msg}")
+            raise ValueError(
+                f"Query did not succeed (state={response.status.state}): {error_message}"
+            )
 
-        # 5. Ensure manifest and result are present
-        if resp.manifest is None or resp.manifest.schema is None:
+        # Validate the presence of the result and schema
+        if not response.manifest or not response.manifest.schema:
             raise ValueError("Query succeeded but schema manifest is missing.")
-        if resp.result is None or resp.result.data_array is None:
+        if not response.result or not response.result.data_array:
             raise ValueError("Query succeeded but result data is missing.")
 
-        # 6. Extract column names and rows
-        column_names = [col.name for col in resp.manifest.schema]
-        data_array = resp.result.data_array
+        # Extract column names (schema.columns holds the ColumnInfo list) and data rows
+        column_names = [column.name for column in response.manifest.schema.columns]
+        data_rows = response.result.data_array
 
-        return [dict(zip(column_names, row)) for row in data_array]
+        # Combine column names with the corresponding row values
+        return [dict(zip(column_names, row)) for row in data_rows]
diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py
index 56a01675..4208ce66 100644
--- a/src/webapp/routers/data.py
+++ b/src/webapp/routers/data.py
@@ -1023,6 +1023,28 @@ def get_upload_url(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
 
 
+# Test endpoint: fetch rows from a fixed test table via the SQL warehouse
+@router.get("/inference/test")
+def test() -> Any:
+    """Fetches up to 500 rows from a hard-coded test table and returns them."""
+    # Raise the error at this level; otherwise it gets wrapped in a 200 response.
+
+    try:
+        dbc = DatabricksControl()
+        rows = dbc.fetch_table_data(
+            catalog_name="dev_sst_02",
+            schema_name="default",
+            table_name="test_dataset",
+            warehouse_id="28e1cbabfe6deb87",
+            limit=500,
+        )
+
+        return rows
+    except ValueError as ve:
+        # Return a 400 error with the specific message from the ValueError
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
+
+
 # Get SHAP Values for Inference
 @router.get("/{inst_id}/inference/top-features/{run_id}", response_model=str)
 def get_top_features(