diff --git a/src/webapp/authn.py b/src/webapp/authn.py
index 31e53046..1f60c4ed 100644
--- a/src/webapp/authn.py
+++ b/src/webapp/authn.py
@@ -53,6 +53,10 @@ def get_api_key(
     )
 
 
+def check_creds(username: str, password: str) -> bool:
+    return username == env_vars["USERNAME"] and password == env_vars["PASSWORD"]
+
+
 def verify_password(plain_password: str, hashed_password: str) -> bool:
     """Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel
     Generates hashes that start with 2y. The hashing scheme recognizes both."""
diff --git a/src/webapp/config.py b/src/webapp/config.py
index 29f6aad0..7d62388b 100644
--- a/src/webapp/config.py
+++ b/src/webapp/config.py
@@ -13,6 +13,10 @@
     "API_KEY_ISSUERS": [],
     "INITIAL_API_KEY": "",
     "INITIAL_API_KEY_ID": "",
+    "CATALOG_NAME": "",
+    "SQL_WAREHOUSE_ID": "",
+    "USERNAME": "",
+    "PASSWORD": "",
 }
 
 # The INSTANCE_HOST is the private IP of CLoudSQL instance e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex)
diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py
index 1761891b..9ba960e6 100644
--- a/src/webapp/databricks.py
+++ b/src/webapp/databricks.py
@@ -4,9 +4,11 @@
 from pydantic import BaseModel
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service import catalog
-
+from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout, StatementState
 from .config import databricks_vars, gcs_vars
 from .utilities import databricksify_inst_name, SchemaType
+from typing import List, Any
+import time
 
 # List of data medallion levels
 MEDALLION_LEVELS = ["silver", "gold", "bronze"]
@@ -191,3 +193,65 @@ def delete_inst(self, inst_name: str) -> None:
                     full_name=f"{cat_name}.{db_inst_name}_{medallion}.{table}"
                 )
             w.schemas.delete(full_name=f"{cat_name}.{db_inst_name}_{medallion}")
+
+    def fetch_table_data(
+        self,
+        catalog_name: str,
+        schema_name: str,
+        table_name: str,
+        warehouse_id: str,
+        limit: int = 1000,
+    ) -> List[dict[str, Any]]:
+        """
+        Runs a simple SELECT * FROM <catalog>.<schema>.<table> LIMIT <limit>
+        against the specified SQL warehouse, and returns a list of row dicts.
+        """
+        w = WorkspaceClient(
+            host=databricks_vars["DATABRICKS_HOST_URL"],
+            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
+        )
+        if not w:
+            raise ValueError(
+                "fetch_table_data(): could not initialize WorkspaceClient."
+            )
+
+        fq_table = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
+        sql = f"SELECT * FROM {fq_table} LIMIT {limit}"
+
+        resp = w.statement_execution.execute_statement(
+            warehouse_id=warehouse_id,
+            statement=sql,
+            format=Format.JSON_ARRAY,
+            wait_timeout="10s",
+            on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE,
+        )
+
+        status = getattr(resp, "status", None)
+        if status and status.state == StatementState.SUCCEEDED and getattr(resp, "result", None):
+            # resp.result.data_array holds the row arrays; resp.manifest.schema.columns the column metadata
+            column_names = [col.name for col in resp.manifest.schema.columns]
+            rows = resp.result.data_array
+        else:
+            # A. If the SQL didn't finish in 10 seconds, resp.statement_id will be set.
+            stmt_id = getattr(resp, "statement_id", None)
+            if not stmt_id:
+                raise ValueError(
+                    f"fetch_table_data(): unexpected response state: {resp}"
+                )
+
+            # B. Poll until the statement succeeds (or fails/cancels).
+            status = resp.status.state if getattr(resp, "status", None) else None
+            while status not in (StatementState.SUCCEEDED, StatementState.FAILED, StatementState.CANCELED):
+                time.sleep(1)
+                resp2 = w.statement_execution.get_statement(statement_id=stmt_id)
+                status = resp2.status.state if getattr(resp2, "status", None) else None
+                resp = resp2
+            if status != StatementState.SUCCEEDED:
+                raise ValueError(f"fetch_table_data(): query ended with state {status}")
+
+            # C. At this point, resp holds the final manifest and first chunk.
+            column_names = [col.name for col in resp.manifest.schema.columns]
+            rows = resp.result.data_array
+
+        # Transform each row (a list of values) into a dict
+        return [dict(zip(column_names, row)) for row in rows]
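Reviewer note: the polling loop above never gives up, so a hung warehouse blocks the request handler indefinitely. A minimal sketch of a bounded variant against the same databricks-sdk client; the helper name `wait_for_statement` and the `max_wait_s` cap are illustrative, not part of this PR:

```python
import time

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.sql import StatementState

TERMINAL = (StatementState.SUCCEEDED, StatementState.FAILED, StatementState.CANCELED)


def wait_for_statement(w: WorkspaceClient, statement_id: str, max_wait_s: float = 300.0):
    """Poll a statement until it reaches a terminal state or the deadline passes."""
    deadline = time.monotonic() + max_wait_s
    while time.monotonic() < deadline:
        resp = w.statement_execution.get_statement(statement_id=statement_id)
        if resp.status and resp.status.state in TERMINAL:
            return resp
        time.sleep(1)  # fixed 1s interval; an exponential backoff would also work
    raise TimeoutError(f"statement {statement_id} still running after {max_wait_s}s")
```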
diff --git a/src/webapp/main.py b/src/webapp/main.py
index 0d2f90e0..47d14964 100644
--- a/src/webapp/main.py
+++ b/src/webapp/main.py
@@ -6,6 +6,7 @@
 import secrets
 from fastapi import FastAPI, Depends, HTTPException, status, Security
 from fastapi.responses import FileResponse
+from fastapi.security import OAuth2PasswordRequestForm
 from pydantic import BaseModel
 from sqlalchemy.future import select
 from sqlalchemy import update
@@ -37,6 +38,7 @@
     create_access_token,
     get_api_key,
     get_api_key_hash,
+    check_creds,
 )
 
 # Set the logging
@@ -99,13 +101,16 @@ async def access_token_from_api_key(
 ) -> Token:
     """Generate a token from an API key."""
     local_session.set(sql_session)
+
     user = authenticate_api_key(api_key_enduser_tuple, local_session.get())
+
     if not user:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="API key not valid",
+            detail="Invalid API key or credentials",
             headers={"WWW-Authenticate": "X-API-KEY"},
         )
+
     access_token_expires = timedelta(
         minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"])
     )
diff --git a/src/webapp/main_test.py b/src/webapp/main_test.py
index 9c6e99e3..df4b3dfd 100644
--- a/src/webapp/main_test.py
+++ b/src/webapp/main_test.py
@@ -13,6 +13,7 @@
     get_session,
     ApiKeyTable,
 )
+from unittest.mock import patch
 from .authn import get_password_hash, get_api_key_hash
 from .test_helper import (
     DATAKINDER,
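Reviewer note: `patch` is imported here but no new test appears in this hunk. A sketch of the kind of hermetic test it enables; the patch target module path, the `client` fixture, and the ID values are all assumptions:

```python
from unittest.mock import patch

FAKE_ROWS = [{"feature": "gpa", "impact": 0.42}]
INST_ID = "00000000-0000-0000-0000-000000000000"  # placeholder
RUN_ID = "run123"  # placeholder


def test_top_features_stubbed(client):  # `client`: a FastAPI TestClient fixture (assumed)
    # Patch the Databricks call so the endpoint test never leaves the process.
    with patch(
        "webapp.databricks.DatabricksControl.fetch_table_data",  # import path assumed
        return_value=FAKE_ROWS,
    ) as mock_fetch:
        resp = client.get(f"/{INST_ID}/inference/top-features/{RUN_ID}")
    assert resp.status_code == 200
    mock_fetch.assert_called_once()
```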
diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py
index f72d701b..56a01675 100644
--- a/src/webapp/routers/data.py
+++ b/src/webapp/routers/data.py
@@ -12,6 +12,7 @@
 import os
 import logging
 from sqlalchemy.exc import IntegrityError
+from ..config import env_vars
 
 from ..utilities import (
     has_access_to_inst_or_err,
@@ -31,8 +32,10 @@
     local_session,
     BatchTable,
     FileTable,
+    InstTable,
 )
 
+from ..databricks import DatabricksControl
 from ..gcsdbutils import update_db_from_bucket
 from ..gcsutil import StorageControl
 
@@ -1018,3 +1021,220 @@ def get_upload_url(
     except ValueError as ve:
         # Return a 400 error with the specific message from ValueError
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
+
+
+# Get SHAP top-impact features for an inference run
+@router.get("/{inst_id}/inference/top-features/{run_id}", response_model=List[dict[str, Any]])
+def get_top_features(
+    inst_id: str,
+    run_id: str,
+    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
+    sql_session: Annotated[Session, Depends(get_session)],
+) -> List[dict[str, Any]]:
+    """Returns the top-impact (SHAP) features table for an inference run."""
+    # raise the error at this level; otherwise it gets wrapped as a 200
+    # has_access_to_inst_or_err(inst_id, current_user)
+    local_session.set(sql_session)
+    query_result = (
+        local_session.get()
+        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
+        .all()
+    )
+    if not query_result:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Institution not found.",
+        )
+    if len(query_result) > 1:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Institution duplicates found.",
+        )
+
+    try:
+        dbc = DatabricksControl()
+        rows = dbc.fetch_table_data(
+            catalog_name=env_vars["CATALOG_NAME"],
+            schema_name=f"{query_result[0][0].name}_silver",
+            table_name=f"sample_inference_{run_id}_features_with_most_impact",
+            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
+            limit=500,
+        )
+
+        return rows
+    except ValueError as ve:
+        # Return a 400 error with the specific message from ValueError
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
+
+
+# Get support overview for an inference run
+@router.get("/{inst_id}/inference/support-overview/{run_id}", response_model=List[dict[str, Any]])
+def get_support_overview(
+    inst_id: str,
+    run_id: str,
+    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
+    sql_session: Annotated[Session, Depends(get_session)],
+) -> List[dict[str, Any]]:
+    """Returns the support overview table for an inference run."""
+    # raise the error at this level; otherwise it gets wrapped as a 200
+    # has_access_to_inst_or_err(inst_id, current_user)
+    local_session.set(sql_session)
+    query_result = (
+        local_session.get()
+        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
+        .all()
+    )
+    if not query_result:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Institution not found.",
+        )
+    if len(query_result) > 1:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Institution duplicates found.",
+        )
+
+    try:
+        dbc = DatabricksControl()
+        rows = dbc.fetch_table_data(
+            catalog_name=env_vars["CATALOG_NAME"],
+            schema_name=f"{query_result[0][0].name}_silver",
+            table_name=f"sample_inference_{run_id}_support_overview",
+            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
+            limit=500,
+        )
+
+        return rows
+    except ValueError as ve:
+        # Return a 400 error with the specific message from ValueError
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
+
+
+@router.get("/{inst_id}/inference/feature_value/{run_id}", response_model=List[dict[str, Any]])
+def get_feature_value(
+    inst_id: str,
+    run_id: str,
+    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
+    sql_session: Annotated[Session, Depends(get_session)],
+) -> List[dict[str, Any]]:
+    """Returns the SHAP feature-importance table for an inference run."""
+    # raise the error at this level; otherwise it gets wrapped as a 200
+    # has_access_to_inst_or_err(inst_id, current_user)
+    local_session.set(sql_session)
+    query_result = (
+        local_session.get()
+        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
+        .all()
+    )
+    if not query_result:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Institution not found.",
+        )
+    if len(query_result) > 1:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Institution duplicates found.",
+        )
+
+    try:
+        dbc = DatabricksControl()
+        rows = dbc.fetch_table_data(
+            catalog_name=env_vars["CATALOG_NAME"],
+            schema_name=f"{query_result[0][0].name}_silver",
+            table_name=f"sample_inference_{run_id}_shap_feature_importance",
+            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
+            limit=500,
+        )
+
+        return rows
+    except ValueError as ve:
+        # Return a 400 error with the specific message from ValueError
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
+
+
+@router.get("/{inst_id}/training/confusion_matrix/{run_id}", response_model=List[dict[str, Any]])
+def get_confusion_matrix(
+    inst_id: str,
+    run_id: str,
+    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
+    sql_session: Annotated[Session, Depends(get_session)],
+) -> List[dict[str, Any]]:
+    """Returns the confusion matrix table for a training run."""
+    # raise the error at this level; otherwise it gets wrapped as a 200
+    # has_access_to_inst_or_err(inst_id, current_user)
+    local_session.set(sql_session)
+    query_result = (
+        local_session.get()
+        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
+        .all()
+    )
+    if not query_result:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Institution not found.",
+        )
+    if len(query_result) > 1:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Institution duplicates found.",
+        )
+
+    try:
+        dbc = DatabricksControl()
+        rows = dbc.fetch_table_data(
+            catalog_name=env_vars["CATALOG_NAME"],
+            schema_name=f"{query_result[0][0].name}_silver",
+            table_name=f"sample_training_{run_id}_confusion_matrix",
+            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
+            limit=500,
+        )
+
+        return rows
+    except ValueError as ve:
+        # Return a 400 error with the specific message from ValueError
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
+
+
+@router.get("/{inst_id}/training/roc_curve/{run_id}", response_model=List[dict[str, Any]])
+def get_roc_curve(
+    inst_id: str,
+    run_id: str,
+    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
+    sql_session: Annotated[Session, Depends(get_session)],
+) -> List[dict[str, Any]]:
+    """Returns the ROC curve table for a training run."""
+    # raise the error at this level; otherwise it gets wrapped as a 200
+    # has_access_to_inst_or_err(inst_id, current_user)
+    local_session.set(sql_session)
+    query_result = (
+        local_session.get()
+        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
+        .all()
+    )
+    if not query_result:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Institution not found.",
+        )
+    if len(query_result) > 1:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Institution duplicates found.",
+        )
+
+    try:
+        dbc = DatabricksControl()
+        rows = dbc.fetch_table_data(
+            catalog_name=env_vars["CATALOG_NAME"],
+            schema_name=f"{query_result[0][0].name}_silver",
+            table_name=f"sample_training_{run_id}_roc_curve",
+            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
+            limit=500,
+        )
+
+        return rows
+    except ValueError as ve:
+        # Return a 400 error with the specific message from ValueError
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
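Reviewer note: the five endpoints above are identical except for the docstring and table-name template. A possible consolidation, assuming data.py's existing imports (`local_session`, `select`, `InstTable`, `str_to_uuid`, `DatabricksControl`, `env_vars`, `HTTPException`, `status`); the helper name `_fetch_inst_silver_table` is illustrative and not part of this PR:

```python
def _fetch_inst_silver_table(
    inst_id: str, sql_session: Session, table_name: str
) -> List[dict[str, Any]]:
    """Shared institution lookup + Databricks fetch for the run-artifact endpoints."""
    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Institution not found.")
    if len(query_result) > 1:
        raise HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR, "Institution duplicates found."
        )
    try:
        return DatabricksControl().fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{query_result[0][0].name}_silver",
            table_name=table_name,
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, str(ve))


# Each endpoint body then reduces to a one-liner, e.g.:
# return _fetch_inst_silver_table(
#     inst_id, sql_session, f"sample_training_{run_id}_confusion_matrix"
# )
```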
diff --git a/src/worker/databricks.py b/src/worker/databricks.py
deleted file mode 100644
index dab643a7..00000000
--- a/src/worker/databricks.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import requests
-from databricks import sql
-from google.auth.transport.requests import Request
-from google.oauth2 import id_token
-from typing import Any
-
-
-class DatabricksSQLConnector:
-    """
-    Helper to get a Databricks SQL connection via GCP service account identity token.
-    """
-
-    def __init__(
-        self, databricks_host: str, http_path: str, client_id: str, client_secret: str
-    ):
-        self.databricks_host = databricks_host
-        self.http_path = http_path
-        self.client_id = client_id
-        self.client_secret = client_secret
-        self.token_exchange_url = f"{self.databricks_host}/oidc/v1/token"
-
-    def _get_google_id_token(
-        self, audience: str = "https://accounts.google.com"
-    ) -> Any:
-        """Fetch a GCP identity token for the service account."""
-        return id_token.fetch_id_token(Request(), audience)
-
-    def _exchange_token_for_databricks(self, subject_token: str) -> Any:
-        """Exchange GCP identity token for Databricks access token."""
-        response = requests.post(
-            self.token_exchange_url,
-            data={
-                "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
-                "subject_token": subject_token,
-                "subject_token_type": "urn:ietf:params:oauth:token-type:id_token",
-                "scope": "openid offline_access",
-            },
-            auth=(self.client_id, self.client_secret),
-        )
-        if response.status_code != 200:
-            raise RuntimeError(f"Databricks token exchange failed: {response.text}")
-        return response.json()["access_token"]
-
-    def get_sql_connection(self) -> Any:
-        """Authenticate and return a Databricks SQL connection."""
-        id_token_str = self._get_google_id_token()
-        access_token = self._exchange_token_for_databricks(id_token_str)
-        return sql.connect(
-            server_hostname=self.databricks_host.replace("https://", ""),
-            http_path=self.http_path,
-            access_token=access_token,
-        )
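Reviewer note: the deleted helper's manual GCP-to-Databricks token exchange is superseded by the SDK's built-in Google credentials support, as used in webapp/databricks.py. A minimal sketch of the replacement auth path; the host and service-account values are placeholders:

```python
from databricks.sdk import WorkspaceClient

# The SDK exchanges the service account's Google identity token for a
# Databricks token internally; no manual /oidc/v1/token call is needed.
w = WorkspaceClient(
    host="https://<workspace-host>",  # e.g. databricks_vars["DATABRICKS_HOST_URL"]
    google_service_account="worker@<project>.iam.gserviceaccount.com",  # placeholder
)
```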
- """ - - def __init__( - self, databricks_host: str, http_path: str, client_id: str, client_secret: str - ): - self.databricks_host = databricks_host - self.http_path = http_path - self.client_id = client_id - self.client_secret = client_secret - self.token_exchange_url = f"{self.databricks_host}/oidc/v1/token" - - def _get_google_id_token( - self, audience: str = "https://accounts.google.com" - ) -> Any: - """Fetch a GCP identity token for the service account.""" - return id_token.fetch_id_token(Request(), audience) - - def _exchange_token_for_databricks(self, subject_token: str) -> Any: - """Exchange GCP identity token for Databricks access token.""" - response = requests.post( - self.token_exchange_url, - data={ - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "subject_token": subject_token, - "subject_token_type": "urn:ietf:params:oauth:token-type:id_token", - "scope": "openid offline_access", - }, - auth=(self.client_id, self.client_secret), - ) - if response.status_code != 200: - raise RuntimeError(f"Databricks token exchange failed: {response.text}") - return response.json()["access_token"] - - def get_sql_connection(self) -> Any: - """Authenticate and return a Databricks SQL connection.""" - id_token_str = self._get_google_id_token() - access_token = self._exchange_token_for_databricks(id_token_str) - return sql.connect( - server_hostname=self.databricks_host.replace("https://", ""), - http_path=self.http_path, - access_token=access_token, - ) diff --git a/src/worker/main.py b/src/worker/main.py index 6e0f125f..85ef4078 100644 --- a/src/worker/main.py +++ b/src/worker/main.py @@ -18,10 +18,9 @@ transfer_file, sftp_file_to_gcs_helper, validate_sftp_file, + confusion_matrix_table, ) -from .databricks import DatabricksSQLConnector - from .config import sftp_vars, env_vars, startup_env_vars from .authn import Token, get_current_username, check_creds, create_access_token from datetime import timedelta @@ -233,30 +232,19 @@ async def execute_pdp_pull( } -# Get SHAP Values for Inference -@app.get("/{inst_id}/top-features/{run_id}", response_model=str) -def get_top_features( - inst_id: str, +@app.get("/confusion-matrix-test") +async def confusion_matrix_test( run_id: str, - current_username: Annotated[str, Depends(get_current_username)], + inst_id: str, ) -> Any: - """Returns a signed URL for uploading data to a specific institution.""" - # raise error at this level instead bc otherwise it's getting wrapped as a 200 - - try: - connector = DatabricksSQLConnector( - databricks_host=env_vars["DATABRICKS_HOST"], - http_path=env_vars["DATABRICKS_SQL_HTTP_PATH"], - client_id=env_vars["DATABRICKS_CLIENT_ID"], - client_secret=env_vars["DATABRICKS_CLIENT_SECRET"], - ) + """Performs the PDP pull of the file.""" - conn = connector.get_sql_connection() - cursor = conn.cursor() - cursor.execute( - "SELECT * FROM staging_sst_01.metropolitan_state_uni_of_denver_gold.sample_inference_66d9716883be4b01a4ea4de82f2d09d5_features_with_most_impact LIMIT 10" - ) - print(cursor.fetchall()) - except ValueError as ve: - # Return a 400 error with the specific message from ValueError - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + result = confusion_matrix_table( + institution_id=inst_id, + webapp_url=env_vars["WEBAPP_URL"], + backend_api_key=env_vars["BACKEND_API_KEY"], + run_id=run_id, + ) + + # Aggregate results to return + return result diff --git a/src/worker/utilities.py b/src/worker/utilities.py index 23c34447..b23c9d6f 100644 --- a/src/worker/utilities.py 
diff --git a/src/worker/utilities.py b/src/worker/utilities.py
index 23c34447..b23c9d6f 100644
--- a/src/worker/utilities.py
+++ b/src/worker/utilities.py
@@ -583,3 +583,44 @@
     except Exception as e:
         logger.exception("<<<< ???? Exception during file validation request.")
         return {"error": f"Exception during validation: {e}"}
+
+
+def confusion_matrix_table(
+    institution_id: str, webapp_url: str, backend_api_key: str, run_id: str
+) -> Any:
+    """
+    Fetches the confusion matrix table for a training run from the webapp API.
+
+    Args:
+        institution_id (str): The ID of the institution the run belongs to.
+        webapp_url (str): The base URL of the webapp API.
+        backend_api_key (str): The API key used to obtain a bearer token.
+        run_id (str): The ID of the training run.
+
+    Returns:
+        Any: The parsed JSON table on success, or an error message/dict.
+    """
+    access_token = get_token(backend_api_key=backend_api_key, webapp_url=webapp_url)
+    if not access_token:
+        logger.error("<<<< ???? Access token not found in the response.")
+        return "Access token not found in the response."
+
+    url = f"{webapp_url}/api/v1/institutions/{institution_id}/training/confusion_matrix/{run_id}"
+    headers = {"accept": "application/json", "Authorization": f"Bearer {access_token}"}
+
+    logger.debug(f">>>> Retrieving confusion matrix table from {url}")
+
+    try:
+        response = requests.get(url, headers=headers, timeout=30)
+
+        if response.status_code == 200:
+            logger.info(">>>> Confusion matrix retrieval successful.")
+            return response.json()
+
+        error_msg = f"Failed to retrieve confusion matrix: {response.status_code} {response.text}"
+        logger.error(f"<<<< ???? {error_msg}")
+        return {"error": error_msg}
+
+    except Exception as e:
+        logger.exception("<<<< ???? Exception during confusion matrix request.")
+        return {"error": f"Exception during retrieval: {e}"}
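Reviewer note: a usage sketch of the new helper (values are placeholders). Because it returns either a parsed table, a plain error string (missing token), or an error dict, callers need to branch on the result type:

```python
table = confusion_matrix_table(
    institution_id="<inst-uuid>",
    webapp_url="https://<webapp-host>",
    backend_api_key="<api-key>",
    run_id="<run-id>",
)
if isinstance(table, str) or (isinstance(table, dict) and "error" in table):
    print(f"fetch failed: {table}")
else:
    for row in table:  # list of row dicts from the webapp endpoint
        print(row)
```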