From 53ebd9258d4381a314ea8d4d2a81a51420971032 Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 15:35:51 -0500 Subject: [PATCH 01/26] feat: added all FE tables --- src/webapp/config.py | 2 + src/webapp/databricks.py | 67 +++++++++++ src/webapp/routers/data.py | 220 +++++++++++++++++++++++++++++++++++++ src/worker/databricks.py | 52 --------- src/worker/main.py | 40 +++---- src/worker/utilities.py | 40 +++++++ 6 files changed, 343 insertions(+), 78 deletions(-) delete mode 100644 src/worker/databricks.py diff --git a/src/webapp/config.py b/src/webapp/config.py index 29f6aad0..db6df647 100644 --- a/src/webapp/config.py +++ b/src/webapp/config.py @@ -13,6 +13,8 @@ "API_KEY_ISSUERS": [], "INITIAL_API_KEY": "", "INITIAL_API_KEY_ID": "", + "CATALOG_NAME": "", + "SQL_WAREHOUSE_ID": "", } # The INSTANCE_HOST is the private IP of CLoudSQL instance e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 1761891b..4b8aa4ce 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -7,6 +7,8 @@ from .config import databricks_vars, gcs_vars from .utilities import databricksify_inst_name, SchemaType +from typing import List, Any +import time # List of data medallion levels MEDALLION_LEVELS = ["silver", "gold", "bronze"] @@ -191,3 +193,68 @@ def delete_inst(self, inst_name: str) -> None: full_name=f"{cat_name}.{db_inst_name}_{medallion}.{table}" ) w.schemas.delete(full_name=f"{cat_name}.{db_inst_name}_{medallion}") + + def fetch_table_data( + self, + catalog_name: str, + schema_name: str, + table_name: str, + warehouse_id: str, + limit: int = 1000, + ) -> List[dict[str, Any]]: + """ + Runs a simple SELECT * FROM .. LIMIT + against the specified SQL warehouse, and returns a list of row‐dicts. + """ + w = WorkspaceClient( + host=databricks_vars["DATABRICKS_HOST_URL"], + google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], + ) + if not w: + raise ValueError( + "fetch_table_data(): could not initialize WorkspaceClient." + ) + + fq_table = f"`{catalog_name}`.`{schema_name}`.`{table_name}`" + sql = f"SELECT * FROM {fq_table} LIMIT {limit}" + + resp = w.statement_execution.execute_statement( + warehouse_id=warehouse_id, + statement=sql, + format="JSON_ARRAY", + wait_timeout=10, + on_wait_timeout="CONTINUE", + ) + + if ( + getattr(resp, "status", None) + and resp.status.state == "SUCCEEDED" + and getattr(resp, "results", None) + ): + # resp.results is a list of row‐arrays, resp.schema is a list of column metadata + column_names = [col.name for col in resp.schema] + rows = resp.results + else: + # A. If the SQL didn’t finish in 10 seconds, resp.statement_id will be set. + stmt_id = getattr(resp, "statement_id", None) + if not stmt_id: + raise ValueError( + f"fetch_table_data(): unexpected response state: {resp}" + ) + + # B. Poll until the statement succeeds (or fails/cancels) + status = resp.status.state if getattr(resp, "status", None) else None + while status not in ("SUCCEEDED", "FAILED", "CANCELED"): + time.sleep(1) + resp2 = w.statement_execution.get_statement(statement_id=stmt_id) + status = resp2.status.state + resp = resp2 + if status != "SUCCEEDED": + raise ValueError(f"fetch_table_data(): query ended with state {status}") + + # C. At this point, resp holds the final manifest and first chunk + column_names = [col.name for col in resp.schema] + rows = resp.results + + # Transform each row (a list of values) into a dict + return [dict(zip(column_names, row)) for row in rows] diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index f72d701b..17e35301 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -12,6 +12,7 @@ import os import logging from sqlalchemy.exc import IntegrityError +from ..config import env_vars from ..utilities import ( has_access_to_inst_or_err, @@ -31,8 +32,10 @@ local_session, BatchTable, FileTable, + InstTable, ) +from ..databricks import DatabricksControl from ..gcsdbutils import update_db_from_bucket from ..gcsutil import StorageControl @@ -1018,3 +1021,220 @@ def get_upload_url( except ValueError as ve: # Return a 400 error with the specific message from ValueError raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + + +# Get SHAP Values for Inference +@router.get("/{inst_id}/inference/top-features/{run_id}", response_model=str) +def get_top_features( + inst_id: str, + run_id: str, + current_user: Annotated[BaseUser, Depends(get_current_active_user)], + sql_session: Annotated[Session, Depends(get_session)], +) -> List[dict[str, Any]]: + """Returns a signed URL for uploading data to a specific institution.""" + # raise error at this level instead bc otherwise it's getting wrapped as a 200 + has_access_to_inst_or_err(inst_id, current_user) + local_session.set(sql_session) + query_result = ( + local_session.get() + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) + .all() + ) + if not query_result or len(query_result) == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Institution not found.", + ) + if len(query_result) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution duplicates found.", + ) + + try: + dbc = DatabricksControl() + rows = dbc.fetch_table_data( + catalog_name=env_vars["CATALOG_NAME"], + schema_name=f"{query_result[0][0].name}_silver", + table_name=f"sample_inference_{run_id}_features_with_most_impact", + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + limit=500, + ) + + return rows + except ValueError as ve: + # Return a 400 error with the specific message from ValueError + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + + +# Get SHAP Values for Inference +@router.get("/{inst_id}/inference/support-overview/{run_id}", response_model=str) +def get_support_overview( + inst_id: str, + run_id: str, + current_user: Annotated[BaseUser, Depends(get_current_active_user)], + sql_session: Annotated[Session, Depends(get_session)], +) -> List[dict[str, Any]]: + """Returns a signed URL for uploading data to a specific institution.""" + # raise error at this level instead bc otherwise it's getting wrapped as a 200 + has_access_to_inst_or_err(inst_id, current_user) + local_session.set(sql_session) + query_result = ( + local_session.get() + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) + .all() + ) + if not query_result or len(query_result) == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Institution not found.", + ) + if len(query_result) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution duplicates found.", + ) + + try: + dbc = DatabricksControl() + rows = dbc.fetch_table_data( + catalog_name=env_vars["CATALOG_NAME"], + schema_name=f"{query_result[0][0].name}_silver", + table_name=f"sample_inference_{run_id}_support_overview", + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + limit=500, + ) + + return rows + except ValueError as ve: + # Return a 400 error with the specific message from ValueError + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + + +@router.get("/{inst_id}/inference/feature_value/{run_id}", response_model=str) +def get_feature_value( + inst_id: str, + run_id: str, + current_user: Annotated[BaseUser, Depends(get_current_active_user)], + sql_session: Annotated[Session, Depends(get_session)], +) -> List[dict[str, Any]]: + """Returns a signed URL for uploading data to a specific institution.""" + # raise error at this level instead bc otherwise it's getting wrapped as a 200 + has_access_to_inst_or_err(inst_id, current_user) + local_session.set(sql_session) + query_result = ( + local_session.get() + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) + .all() + ) + if not query_result or len(query_result) == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Institution not found.", + ) + if len(query_result) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution duplicates found.", + ) + + try: + dbc = DatabricksControl() + rows = dbc.fetch_table_data( + catalog_name=env_vars["CATALOG_NAME"], + schema_name=f"{query_result[0][0].name}_silver", + table_name=f"sample_inference_{run_id}_shap_feature_importance", + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + limit=500, + ) + + return rows + except ValueError as ve: + # Return a 400 error with the specific message from ValueError + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + + +@router.get("/{inst_id}/training/confusion_matrix/{run_id}", response_model=str) +def get_confusion_matrix( + inst_id: str, + run_id: str, + current_user: Annotated[BaseUser, Depends(get_current_active_user)], + sql_session: Annotated[Session, Depends(get_session)], +) -> List[dict[str, Any]]: + """Returns a signed URL for uploading data to a specific institution.""" + # raise error at this level instead bc otherwise it's getting wrapped as a 200 + has_access_to_inst_or_err(inst_id, current_user) + local_session.set(sql_session) + query_result = ( + local_session.get() + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) + .all() + ) + if not query_result or len(query_result) == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Institution not found.", + ) + if len(query_result) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution duplicates found.", + ) + + try: + dbc = DatabricksControl() + rows = dbc.fetch_table_data( + catalog_name=env_vars["CATALOG_NAME"], + schema_name=f"{query_result[0][0].name}_silver", + table_name=f"sample_training_{run_id}_confusion_matrix", + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + limit=500, + ) + + return rows + except ValueError as ve: + # Return a 400 error with the specific message from ValueError + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + + +@router.get("/{inst_id}/training/roc_curve/{run_id}", response_model=str) +def get_roc_curve( + inst_id: str, + run_id: str, + current_user: Annotated[BaseUser, Depends(get_current_active_user)], + sql_session: Annotated[Session, Depends(get_session)], +) -> List[dict[str, Any]]: + """Returns a signed URL for uploading data to a specific institution.""" + # raise error at this level instead bc otherwise it's getting wrapped as a 200 + has_access_to_inst_or_err(inst_id, current_user) + local_session.set(sql_session) + query_result = ( + local_session.get() + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) + .all() + ) + if not query_result or len(query_result) == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Institution not found.", + ) + if len(query_result) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution duplicates found.", + ) + + try: + dbc = DatabricksControl() + rows = dbc.fetch_table_data( + catalog_name=env_vars["CATALOG_NAME"], + schema_name=f"{query_result[0][0].name}_silver", + table_name=f"sample_training_{run_id}_roc_curve", + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + limit=500, + ) + + return rows + except ValueError as ve: + # Return a 400 error with the specific message from ValueError + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) diff --git a/src/worker/databricks.py b/src/worker/databricks.py deleted file mode 100644 index dab643a7..00000000 --- a/src/worker/databricks.py +++ /dev/null @@ -1,52 +0,0 @@ -import requests -from databricks import sql -from google.auth.transport.requests import Request -from google.oauth2 import id_token -from typing import Any - - -class DatabricksSQLConnector: - """ - Helper to get a Databricks SQL connection via GCP service account identity token. - """ - - def __init__( - self, databricks_host: str, http_path: str, client_id: str, client_secret: str - ): - self.databricks_host = databricks_host - self.http_path = http_path - self.client_id = client_id - self.client_secret = client_secret - self.token_exchange_url = f"{self.databricks_host}/oidc/v1/token" - - def _get_google_id_token( - self, audience: str = "https://accounts.google.com" - ) -> Any: - """Fetch a GCP identity token for the service account.""" - return id_token.fetch_id_token(Request(), audience) - - def _exchange_token_for_databricks(self, subject_token: str) -> Any: - """Exchange GCP identity token for Databricks access token.""" - response = requests.post( - self.token_exchange_url, - data={ - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "subject_token": subject_token, - "subject_token_type": "urn:ietf:params:oauth:token-type:id_token", - "scope": "openid offline_access", - }, - auth=(self.client_id, self.client_secret), - ) - if response.status_code != 200: - raise RuntimeError(f"Databricks token exchange failed: {response.text}") - return response.json()["access_token"] - - def get_sql_connection(self) -> Any: - """Authenticate and return a Databricks SQL connection.""" - id_token_str = self._get_google_id_token() - access_token = self._exchange_token_for_databricks(id_token_str) - return sql.connect( - server_hostname=self.databricks_host.replace("https://", ""), - http_path=self.http_path, - access_token=access_token, - ) diff --git a/src/worker/main.py b/src/worker/main.py index 6e0f125f..85ef4078 100644 --- a/src/worker/main.py +++ b/src/worker/main.py @@ -18,10 +18,9 @@ transfer_file, sftp_file_to_gcs_helper, validate_sftp_file, + confusion_matrix_table, ) -from .databricks import DatabricksSQLConnector - from .config import sftp_vars, env_vars, startup_env_vars from .authn import Token, get_current_username, check_creds, create_access_token from datetime import timedelta @@ -233,30 +232,19 @@ async def execute_pdp_pull( } -# Get SHAP Values for Inference -@app.get("/{inst_id}/top-features/{run_id}", response_model=str) -def get_top_features( - inst_id: str, +@app.get("/confusion-matrix-test") +async def confusion_matrix_test( run_id: str, - current_username: Annotated[str, Depends(get_current_username)], + inst_id: str, ) -> Any: - """Returns a signed URL for uploading data to a specific institution.""" - # raise error at this level instead bc otherwise it's getting wrapped as a 200 - - try: - connector = DatabricksSQLConnector( - databricks_host=env_vars["DATABRICKS_HOST"], - http_path=env_vars["DATABRICKS_SQL_HTTP_PATH"], - client_id=env_vars["DATABRICKS_CLIENT_ID"], - client_secret=env_vars["DATABRICKS_CLIENT_SECRET"], - ) + """Performs the PDP pull of the file.""" - conn = connector.get_sql_connection() - cursor = conn.cursor() - cursor.execute( - "SELECT * FROM staging_sst_01.metropolitan_state_uni_of_denver_gold.sample_inference_66d9716883be4b01a4ea4de82f2d09d5_features_with_most_impact LIMIT 10" - ) - print(cursor.fetchall()) - except ValueError as ve: - # Return a 400 error with the specific message from ValueError - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + result = confusion_matrix_table( + institution_id=inst_id, + webapp_url=env_vars["WEBAPP_URL"], + backend_api_key=env_vars["BACKEND_API_KEY"], + run_id=run_id, + ) + + # Aggregate results to return + return result diff --git a/src/worker/utilities.py b/src/worker/utilities.py index 23c34447..5d17559d 100644 --- a/src/worker/utilities.py +++ b/src/worker/utilities.py @@ -583,3 +583,43 @@ def validate_sftp_file( except Exception as e: logger.exception("<<<< ???? Exception during file validation request.") return {"error": f"Exception during validation: {e}"} + + +def confusion_matrix_table( + institution_id: int, webapp_url: str, backend_api_key: str, run_id: str +) -> Any: + """ + Sends a POST request to validate an SFTP file. + + Args: + institution_id (str): The ID of the institution for which the file validation is intended. + file_name (str): The name of the file to be validated. + access_token (str): The bearer token used for authorization. + + Returns: + str: The server's response to the validation request. + """ + access_token = get_token(backend_api_key=backend_api_key, webapp_url=webapp_url) + if not access_token: + logger.error("<<<< ???? Access token not found in the response.") + return "Access token not found in the response." + + url = f"{webapp_url}/api/v1/institutions//{institution_id}/training/confusion_matrix/{run_id}" + headers = {"accept": "application/json", "Authorization": f"Bearer {access_token}"} + + logger.debug(f">>>> Retrieving confusion matric table from {url}") + + try: + response = requests.get(url, headers=headers) + + if response.status_code == 200: + logger.info(">>>> File validation successful.") + return response.json() + + error_msg = f"Failed to validate file: {response.status_code} {response.text}" + logger.error(f"<<<< ???? {error_msg}") + return {"error": error_msg} + + except Exception as e: + logger.exception("<<<< ???? Exception during file validation request.") + return {"error": f"Exception during validation: {e}"} From 1c160fb05386940a19ee19a63a841d5c6320de77 Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 15:41:16 -0500 Subject: [PATCH 02/26] feat: added all FE tables --- src/webapp/databricks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 4b8aa4ce..a804dbd3 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -196,10 +196,10 @@ def delete_inst(self, inst_name: str) -> None: def fetch_table_data( self, - catalog_name: str, - schema_name: str, - table_name: str, - warehouse_id: str, + catalog_name: Any, + schema_name: Any, + table_name: Any, + warehouse_id: Any, limit: int = 1000, ) -> List[dict[str, Any]]: """ From 549f49fe16c1d54fdbebb3d956a23f1426fb55da Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 16:29:19 -0500 Subject: [PATCH 03/26] feat: added all FE tables --- src/webapp/databricks.py | 2 +- src/worker/utilities.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index a804dbd3..32f1a029 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -222,7 +222,7 @@ def fetch_table_data( warehouse_id=warehouse_id, statement=sql, format="JSON_ARRAY", - wait_timeout=10, + wait_timeout="10s", on_wait_timeout="CONTINUE", ) diff --git a/src/worker/utilities.py b/src/worker/utilities.py index 5d17559d..b23c9d6f 100644 --- a/src/worker/utilities.py +++ b/src/worker/utilities.py @@ -586,7 +586,7 @@ def validate_sftp_file( def confusion_matrix_table( - institution_id: int, webapp_url: str, backend_api_key: str, run_id: str + institution_id: str, webapp_url: str, backend_api_key: str, run_id: str ) -> Any: """ Sends a POST request to validate an SFTP file. From d69ed25ce41575ca49f03f128dabf31bd2974a93 Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 16:36:37 -0500 Subject: [PATCH 04/26] feat: added all FE tables --- src/webapp/databricks.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 32f1a029..2d21531d 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from databricks.sdk import WorkspaceClient from databricks.sdk.service import catalog - +from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout from .config import databricks_vars, gcs_vars from .utilities import databricksify_inst_name, SchemaType from typing import List, Any @@ -221,19 +221,16 @@ def fetch_table_data( resp = w.statement_execution.execute_statement( warehouse_id=warehouse_id, statement=sql, - format="JSON_ARRAY", + format=Format.JSON_ARRAY wait_timeout="10s", - on_wait_timeout="CONTINUE", + on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE, ) - if ( - getattr(resp, "status", None) - and resp.status.state == "SUCCEEDED" - and getattr(resp, "results", None) - ): + status = getattr(resp, "status", None) + if status and status.state == "SUCCEEDED" and getattr(resp, "result", None): # resp.results is a list of row‐arrays, resp.schema is a list of column metadata - column_names = [col.name for col in resp.schema] - rows = resp.results + column_names = [col.name for col in resp.result.schema] + rows = resp.result.data_array else: # A. If the SQL didn’t finish in 10 seconds, resp.statement_id will be set. stmt_id = getattr(resp, "statement_id", None) @@ -247,14 +244,14 @@ def fetch_table_data( while status not in ("SUCCEEDED", "FAILED", "CANCELED"): time.sleep(1) resp2 = w.statement_execution.get_statement(statement_id=stmt_id) - status = resp2.status.state + status = resp2.status.state if getattr(resp2, "status", None) else None resp = resp2 if status != "SUCCEEDED": raise ValueError(f"fetch_table_data(): query ended with state {status}") # C. At this point, resp holds the final manifest and first chunk - column_names = [col.name for col in resp.schema] - rows = resp.results + column_names = [col.name for col in resp.result.schema] + rows = resp.result.data_array # Transform each row (a list of values) into a dict return [dict(zip(column_names, row)) for row in rows] From 63e7ce23d71292d5b89db1ee31d04783c71f2dcc Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 16:38:48 -0500 Subject: [PATCH 05/26] feat: added all FE tables --- src/webapp/databricks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 2d21531d..daf5a8e9 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from databricks.sdk import WorkspaceClient from databricks.sdk.service import catalog -from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout +from databricks.sdk.service.sql import Format from .config import databricks_vars, gcs_vars from .utilities import databricksify_inst_name, SchemaType from typing import List, Any @@ -223,7 +223,7 @@ def fetch_table_data( statement=sql, format=Format.JSON_ARRAY wait_timeout="10s", - on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE, + on_wait_timeout="CONTINUE", ) status = getattr(resp, "status", None) From 98f3f97b46a1b5fd96790ba78ef4f6e0f4283b4c Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 16:40:56 -0500 Subject: [PATCH 06/26] feat: added all FE tables --- src/webapp/databricks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index daf5a8e9..7b8bc883 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from databricks.sdk import WorkspaceClient from databricks.sdk.service import catalog -from databricks.sdk.service.sql import Format +from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout from .config import databricks_vars, gcs_vars from .utilities import databricksify_inst_name, SchemaType from typing import List, Any @@ -221,9 +221,9 @@ def fetch_table_data( resp = w.statement_execution.execute_statement( warehouse_id=warehouse_id, statement=sql, - format=Format.JSON_ARRAY + format=Format.JSON_ARRAY, wait_timeout="10s", - on_wait_timeout="CONTINUE", + on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE, ) status = getattr(resp, "status", None) From d28b48c7a4e43f2b69655422b5da963027e1b89c Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 16:45:23 -0500 Subject: [PATCH 07/26] feat: added all FE tables --- src/webapp/databricks.py | 50 +++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 7b8bc883..28daa2d0 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -226,32 +226,40 @@ def fetch_table_data( on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE, ) - status = getattr(resp, "status", None) - if status and status.state == "SUCCEEDED" and getattr(resp, "result", None): - # resp.results is a list of row‐arrays, resp.schema is a list of column metadata - column_names = [col.name for col in resp.result.schema] - rows = resp.result.data_array + if resp.status and resp.status.state == "SUCCEEDED": + result = resp.result + if result and result.schema and result.data_array: + column_names = [col.name for col in result.schema] + rows = result.data_array + else: + raise ValueError("Result is missing schema or data_array.") else: - # A. If the SQL didn’t finish in 10 seconds, resp.statement_id will be set. stmt_id = getattr(resp, "statement_id", None) if not stmt_id: - raise ValueError( - f"fetch_table_data(): unexpected response state: {resp}" - ) + raise ValueError("Missing statement_id in initial response.") - # B. Poll until the statement succeeds (or fails/cancels) - status = resp.status.state if getattr(resp, "status", None) else None - while status not in ("SUCCEEDED", "FAILED", "CANCELED"): + # Poll until completion + while True: time.sleep(1) resp2 = w.statement_execution.get_statement(statement_id=stmt_id) - status = resp2.status.state if getattr(resp2, "status", None) else None - resp = resp2 - if status != "SUCCEEDED": - raise ValueError(f"fetch_table_data(): query ended with state {status}") + if resp2.status and resp2.status.state in ("SUCCEEDED", "FAILED", "CANCELED"): + break + + if not resp2.status or resp2.status.state != "SUCCEEDED": + raise ValueError(f"Query ended with state {resp2.status.state if resp2.status else 'UNKNOWN'}") + + result = resp2.result + if result and result.schema and result.data_array: + column_names = [col.name for col in result.schema] + rows = result.data_array + else: + raise ValueError("Result is missing schema or data_array.") + + # Final data transformation + if not rows: + return [] - # C. At this point, resp holds the final manifest and first chunk - column_names = [col.name for col in resp.result.schema] - rows = resp.result.data_array + if not all(isinstance(row, list) for row in rows): + raise TypeError("Result rows are not iterable lists") - # Transform each row (a list of values) into a dict - return [dict(zip(column_names, row)) for row in rows] + return [dict(zip(column_names, row)) for row in rows] \ No newline at end of file From f9ff9638c57f087fe0fed31baf6cce12ba0ea596 Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 16:49:25 -0500 Subject: [PATCH 08/26] feat: added all FE tables --- src/webapp/databricks.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 28daa2d0..6734f8f3 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -228,17 +228,17 @@ def fetch_table_data( if resp.status and resp.status.state == "SUCCEEDED": result = resp.result - if result and result.schema and result.data_array: - column_names = [col.name for col in result.schema] + manifest = resp.manifest + if result and manifest and manifest.schema and result.data_array: + column_names = [col.name for col in manifest.schema] rows = result.data_array else: - raise ValueError("Result is missing schema or data_array.") + raise ValueError("Missing result data or schema.") else: stmt_id = getattr(resp, "statement_id", None) if not stmt_id: raise ValueError("Missing statement_id in initial response.") - # Poll until completion while True: time.sleep(1) resp2 = w.statement_execution.get_statement(statement_id=stmt_id) @@ -249,13 +249,13 @@ def fetch_table_data( raise ValueError(f"Query ended with state {resp2.status.state if resp2.status else 'UNKNOWN'}") result = resp2.result - if result and result.schema and result.data_array: - column_names = [col.name for col in result.schema] + manifest = resp2.manifest + if result and manifest and manifest.schema and manifest.schema.columns and result.data_array: + column_names = [col.name for col in manifest.schema.columns] rows = result.data_array else: - raise ValueError("Result is missing schema or data_array.") + raise ValueError("Missing result data or schema.") - # Final data transformation if not rows: return [] From 174bc131ad51b73fe1b2dcb8c256c690d8bc484a Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 16:56:36 -0500 Subject: [PATCH 09/26] feat: added all FE tables --- src/webapp/databricks.py | 50 +++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 6734f8f3..7b8bc883 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -226,40 +226,32 @@ def fetch_table_data( on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE, ) - if resp.status and resp.status.state == "SUCCEEDED": - result = resp.result - manifest = resp.manifest - if result and manifest and manifest.schema and result.data_array: - column_names = [col.name for col in manifest.schema] - rows = result.data_array - else: - raise ValueError("Missing result data or schema.") + status = getattr(resp, "status", None) + if status and status.state == "SUCCEEDED" and getattr(resp, "result", None): + # resp.results is a list of row‐arrays, resp.schema is a list of column metadata + column_names = [col.name for col in resp.result.schema] + rows = resp.result.data_array else: + # A. If the SQL didn’t finish in 10 seconds, resp.statement_id will be set. stmt_id = getattr(resp, "statement_id", None) if not stmt_id: - raise ValueError("Missing statement_id in initial response.") + raise ValueError( + f"fetch_table_data(): unexpected response state: {resp}" + ) - while True: + # B. Poll until the statement succeeds (or fails/cancels) + status = resp.status.state if getattr(resp, "status", None) else None + while status not in ("SUCCEEDED", "FAILED", "CANCELED"): time.sleep(1) resp2 = w.statement_execution.get_statement(statement_id=stmt_id) - if resp2.status and resp2.status.state in ("SUCCEEDED", "FAILED", "CANCELED"): - break - - if not resp2.status or resp2.status.state != "SUCCEEDED": - raise ValueError(f"Query ended with state {resp2.status.state if resp2.status else 'UNKNOWN'}") - - result = resp2.result - manifest = resp2.manifest - if result and manifest and manifest.schema and manifest.schema.columns and result.data_array: - column_names = [col.name for col in manifest.schema.columns] - rows = result.data_array - else: - raise ValueError("Missing result data or schema.") - - if not rows: - return [] + status = resp2.status.state if getattr(resp2, "status", None) else None + resp = resp2 + if status != "SUCCEEDED": + raise ValueError(f"fetch_table_data(): query ended with state {status}") - if not all(isinstance(row, list) for row in rows): - raise TypeError("Result rows are not iterable lists") + # C. At this point, resp holds the final manifest and first chunk + column_names = [col.name for col in resp.result.schema] + rows = resp.result.data_array - return [dict(zip(column_names, row)) for row in rows] \ No newline at end of file + # Transform each row (a list of values) into a dict + return [dict(zip(column_names, row)) for row in rows] From 0ce899dea7b3f65a753d90a1f1b6f668d79bcdc3 Mon Sep 17 00:00:00 2001 From: Mesh Date: Tue, 3 Jun 2025 17:10:32 -0500 Subject: [PATCH 10/26] feat: added all FE tables --- src/webapp/databricks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 7b8bc883..9ba960e6 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -229,7 +229,7 @@ def fetch_table_data( status = getattr(resp, "status", None) if status and status.state == "SUCCEEDED" and getattr(resp, "result", None): # resp.results is a list of row‐arrays, resp.schema is a list of column metadata - column_names = [col.name for col in resp.result.schema] + column_names = [col.name for col in resp.manifest.schema] rows = resp.result.data_array else: # A. If the SQL didn’t finish in 10 seconds, resp.statement_id will be set. @@ -250,7 +250,7 @@ def fetch_table_data( raise ValueError(f"fetch_table_data(): query ended with state {status}") # C. At this point, resp holds the final manifest and first chunk - column_names = [col.name for col in resp.result.schema] + column_names = [col.name for col in resp.manifest.schema] rows = resp.result.data_array # Transform each row (a list of values) into a dict From 160bf2358d38eccc4b1c18aa5f8ca1fd1a3d4e12 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:03:46 -0500 Subject: [PATCH 11/26] feat: added option for api auth --- src/webapp/authn.py | 7 +++++++ src/webapp/main.py | 13 ++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index 31e53046..f626a444 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -52,6 +52,13 @@ def get_api_key( detail="Invalid or missing API Key", ) +def check_creds(username: str, password: str) -> bool: + if username == env_vars["USERNAME"] and password == env_vars["PASSWORD"]: + return True + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Creds for worker job not correct", + ) def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel diff --git a/src/webapp/main.py b/src/webapp/main.py index 0d2f90e0..79649a7c 100644 --- a/src/webapp/main.py +++ b/src/webapp/main.py @@ -6,6 +6,7 @@ import secrets from fastapi import FastAPI, Depends, HTTPException, status, Security from fastapi.responses import FileResponse +from fastapi.security import OAuth2PasswordRequestForm from pydantic import BaseModel from sqlalchemy.future import select from sqlalchemy import update @@ -37,6 +38,7 @@ create_access_token, get_api_key, get_api_key_hash, + check_creds, ) # Set the logging @@ -95,22 +97,27 @@ def read_root() -> Any: @app.post("/token-from-api-key") async def access_token_from_api_key( sql_session: Annotated[Session, Depends(get_session)], + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], api_key_enduser_tuple: str = Security(get_api_key), ) -> Token: """Generate a token from an API key.""" local_session.set(sql_session) + user = authenticate_api_key(api_key_enduser_tuple, local_session.get()) - if not user: + valid = check_creds(form_data.username, form_data.password) + + if not user or not valid: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="API key not valid", + detail="Invalid API key and credentials", headers={"WWW-Authenticate": "X-API-KEY"}, ) + email = user.email if user else form_data.username access_token_expires = timedelta( minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"]) ) access_token = create_access_token( - data={"sub": user.email}, expires_delta=access_token_expires + data={"sub": email}, expires_delta=access_token_expires ) return Token(access_token=access_token, token_type="bearer") From 4c457f3fdec99dfccb149e9a22984e6375db20d6 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:15:26 -0500 Subject: [PATCH 12/26] feat: added option for api auth --- src/webapp/authn.py | 4 +++- src/webapp/main_test.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index f626a444..6fddb95a 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -52,14 +52,16 @@ def get_api_key( detail="Invalid or missing API Key", ) + def check_creds(username: str, password: str) -> bool: if username == env_vars["USERNAME"] and password == env_vars["PASSWORD"]: return True raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Creds for worker job not correct", + detail="Creds for webapp job not correct", ) + def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel Generates hashes that start with 2y. The hashing scheme recognizes both.""" diff --git a/src/webapp/main_test.py b/src/webapp/main_test.py index 9c6e99e3..a3ca47b1 100644 --- a/src/webapp/main_test.py +++ b/src/webapp/main_test.py @@ -149,6 +149,7 @@ def test_retrieve_token_gen_from_api_key(client: TestClient): response = client.post( "/token-from-api-key", headers={"X-API-KEY": "key_1"}, + data={"username": "fake", "password": "fake"} # required form fields ) assert response.status_code == 200 assert response.json()["token_type"] == "bearer" From ce89009ee1ea6146b9bd04bfa8be48c8a2d5cb0c Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:15:42 -0500 Subject: [PATCH 13/26] feat: added option for api auth --- src/webapp/main_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/webapp/main_test.py b/src/webapp/main_test.py index a3ca47b1..de44c5f8 100644 --- a/src/webapp/main_test.py +++ b/src/webapp/main_test.py @@ -149,7 +149,7 @@ def test_retrieve_token_gen_from_api_key(client: TestClient): response = client.post( "/token-from-api-key", headers={"X-API-KEY": "key_1"}, - data={"username": "fake", "password": "fake"} # required form fields + data={"username": "fake", "password": "fake"}, # required form fields ) assert response.status_code == 200 assert response.json()["token_type"] == "bearer" From 6bb6ea3f70de2c74327df8b2703492f0ea29b718 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:18:22 -0500 Subject: [PATCH 14/26] feat: added option for api auth --- src/webapp/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/webapp/config.py b/src/webapp/config.py index db6df647..7d62388b 100644 --- a/src/webapp/config.py +++ b/src/webapp/config.py @@ -15,6 +15,8 @@ "INITIAL_API_KEY_ID": "", "CATALOG_NAME": "", "SQL_WAREHOUSE_ID": "", + "USERNAME": "", + "PASSWORD": "", } # The INSTANCE_HOST is the private IP of CLoudSQL instance e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex) From 33006b61677ea903576f9b406e864f2e23bc2439 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:21:29 -0500 Subject: [PATCH 15/26] feat: added option for api auth --- src/webapp/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/webapp/main.py b/src/webapp/main.py index 79649a7c..9595596f 100644 --- a/src/webapp/main.py +++ b/src/webapp/main.py @@ -106,7 +106,7 @@ async def access_token_from_api_key( user = authenticate_api_key(api_key_enduser_tuple, local_session.get()) valid = check_creds(form_data.username, form_data.password) - if not user or not valid: + if not user and not valid: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key and credentials", From 19aba9db9442f74339a5523dbfc9a0bb9f5104b9 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:27:22 -0500 Subject: [PATCH 16/26] feat: added option for api auth --- src/webapp/authn.py | 7 +------ src/webapp/main_test.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index 6fddb95a..d00c2ec3 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -54,12 +54,7 @@ def get_api_key( def check_creds(username: str, password: str) -> bool: - if username == env_vars["USERNAME"] and password == env_vars["PASSWORD"]: - return True - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Creds for webapp job not correct", - ) + return username == env_vars.get("USERNAME") and password == env_vars.get("PASSWORD") def verify_password(plain_password: str, hashed_password: str) -> bool: diff --git a/src/webapp/main_test.py b/src/webapp/main_test.py index de44c5f8..9e3e078e 100644 --- a/src/webapp/main_test.py +++ b/src/webapp/main_test.py @@ -13,6 +13,7 @@ get_session, ApiKeyTable, ) +from unittest.mock import patch from .authn import get_password_hash, get_api_key_hash from .test_helper import ( DATAKINDER, @@ -145,14 +146,14 @@ def test_get_root(client: TestClient): def test_retrieve_token_gen_from_api_key(client: TestClient): - """Test POST /token-from-api-key.""" - response = client.post( - "/token-from-api-key", - headers={"X-API-KEY": "key_1"}, - data={"username": "fake", "password": "fake"}, # required form fields - ) - assert response.status_code == 200 - assert response.json()["token_type"] == "bearer" + with patch.dict("os.environ", {"USERNAME": "fake", "PASSWORD": "fake"}): + response = client.post( + "/token-from-api-key", + headers={"X-API-KEY": "key_1"}, + data={"username": "fake", "password": "fake"}, + ) + assert response.status_code == 200 + assert response.json()["token_type"] == "bearer" def test_get_cross_isnt_users(client: TestClient): From 6c6efc6fe9b50520eb1aa6751024a66bc9eecfe5 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:37:41 -0500 Subject: [PATCH 17/26] feat: added option for api auth --- src/webapp/authn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index d00c2ec3..1f60c4ed 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -54,7 +54,7 @@ def get_api_key( def check_creds(username: str, password: str) -> bool: - return username == env_vars.get("USERNAME") and password == env_vars.get("PASSWORD") + return username == env_vars["USERNAME"] and password == env_vars["PASSWORD"] def verify_password(plain_password: str, hashed_password: str) -> bool: From 550010a9015e1bbbd8f6797a1ad50ad7019e3313 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:42:15 -0500 Subject: [PATCH 18/26] feat: added option for api auth --- src/webapp/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/webapp/main.py b/src/webapp/main.py index 9595596f..813a02b9 100644 --- a/src/webapp/main.py +++ b/src/webapp/main.py @@ -105,6 +105,7 @@ async def access_token_from_api_key( user = authenticate_api_key(api_key_enduser_tuple, local_session.get()) valid = check_creds(form_data.username, form_data.password) + logger.info(f"user: {user}, creds valid: {valid}") if not user and not valid: raise HTTPException( From f3edcc59d9ad731fa97e34c5b4f39a034829caa9 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:43:04 -0500 Subject: [PATCH 19/26] feat: added option for api auth --- src/webapp/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/webapp/main.py b/src/webapp/main.py index 813a02b9..7c5839a4 100644 --- a/src/webapp/main.py +++ b/src/webapp/main.py @@ -105,7 +105,9 @@ async def access_token_from_api_key( user = authenticate_api_key(api_key_enduser_tuple, local_session.get()) valid = check_creds(form_data.username, form_data.password) - logger.info(f"user: {user}, creds valid: {valid}") + logger.info(f"api_key input: {api_key_enduser_tuple}") + logger.info(f"user: {user}") + logger.info(f"valid creds: {valid}") if not user and not valid: raise HTTPException( From 3b576414944a2a47f06a044eac1e557b0bae55de Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 15:53:23 -0500 Subject: [PATCH 20/26] feat: added option for api auth --- src/webapp/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/webapp/main.py b/src/webapp/main.py index 7c5839a4..0bf86859 100644 --- a/src/webapp/main.py +++ b/src/webapp/main.py @@ -105,9 +105,9 @@ async def access_token_from_api_key( user = authenticate_api_key(api_key_enduser_tuple, local_session.get()) valid = check_creds(form_data.username, form_data.password) - logger.info(f"api_key input: {api_key_enduser_tuple}") - logger.info(f"user: {user}") - logger.info(f"valid creds: {valid}") + print(f"api_key input: {api_key_enduser_tuple}") + print(f"user: {user}") + print(f"valid creds: {valid}") if not user and not valid: raise HTTPException( From 9ed2271be13a7cff78039c34baea030cc44115ea Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 16:18:00 -0500 Subject: [PATCH 21/26] feat: added option for api auth --- src/webapp/main.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/webapp/main.py b/src/webapp/main.py index 0bf86859..06b5b103 100644 --- a/src/webapp/main.py +++ b/src/webapp/main.py @@ -98,24 +98,18 @@ def read_root() -> Any: async def access_token_from_api_key( sql_session: Annotated[Session, Depends(get_session)], form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - api_key_enduser_tuple: str = Security(get_api_key), ) -> Token: """Generate a token from an API key.""" - local_session.set(sql_session) - user = authenticate_api_key(api_key_enduser_tuple, local_session.get()) valid = check_creds(form_data.username, form_data.password) - print(f"api_key input: {api_key_enduser_tuple}") - print(f"user: {user}") - print(f"valid creds: {valid}") - if not user and not valid: + if not valid: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key and credentials", headers={"WWW-Authenticate": "X-API-KEY"}, ) - email = user.email if user else form_data.username + email = form_data.username access_token_expires = timedelta( minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"]) ) From 4bdaea04eea4150b2c6b5768c8af6a2134c6ff74 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 16:30:33 -0500 Subject: [PATCH 22/26] feat: added option for api auth --- src/webapp/authn.py | 5 ----- src/webapp/main.py | 10 ++++++++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index 1f60c4ed..1bcbdd64 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -16,11 +16,6 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -oauth2_apikey_scheme = OAuth2PasswordBearer( - scheme_name="api_key_scheme", - tokenUrl="token-from-api-key", -) - api_key_header = APIKeyHeader(name="X-API-KEY", scheme_name="api-key", auto_error=False) # The INST value may be empty for Datakinder or cross-institution access. api_key_inst_header = APIKeyHeader( diff --git a/src/webapp/main.py b/src/webapp/main.py index 06b5b103..7c5839a4 100644 --- a/src/webapp/main.py +++ b/src/webapp/main.py @@ -98,18 +98,24 @@ def read_root() -> Any: async def access_token_from_api_key( sql_session: Annotated[Session, Depends(get_session)], form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + api_key_enduser_tuple: str = Security(get_api_key), ) -> Token: """Generate a token from an API key.""" + local_session.set(sql_session) + user = authenticate_api_key(api_key_enduser_tuple, local_session.get()) valid = check_creds(form_data.username, form_data.password) + logger.info(f"api_key input: {api_key_enduser_tuple}") + logger.info(f"user: {user}") + logger.info(f"valid creds: {valid}") - if not valid: + if not user and not valid: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key and credentials", headers={"WWW-Authenticate": "X-API-KEY"}, ) - email = form_data.username + email = user.email if user else form_data.username access_token_expires = timedelta( minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"]) ) From ea08663602b5cb1431355d80a6ff4b2691887aa4 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 16:32:32 -0500 Subject: [PATCH 23/26] feat: added option for api auth --- src/webapp/authn.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index 1bcbdd64..1f60c4ed 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -16,6 +16,11 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_apikey_scheme = OAuth2PasswordBearer( + scheme_name="api_key_scheme", + tokenUrl="token-from-api-key", +) + api_key_header = APIKeyHeader(name="X-API-KEY", scheme_name="api-key", auto_error=False) # The INST value may be empty for Datakinder or cross-institution access. api_key_inst_header = APIKeyHeader( From b69c428cb30cac53e11053fefb9e167132ea3639 Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 16:42:00 -0500 Subject: [PATCH 24/26] feat: added option for api auth --- src/webapp/main.py | 11 +++-------- src/webapp/main_test.py | 15 +++++++-------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/webapp/main.py b/src/webapp/main.py index 7c5839a4..47d14964 100644 --- a/src/webapp/main.py +++ b/src/webapp/main.py @@ -97,30 +97,25 @@ def read_root() -> Any: @app.post("/token-from-api-key") async def access_token_from_api_key( sql_session: Annotated[Session, Depends(get_session)], - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], api_key_enduser_tuple: str = Security(get_api_key), ) -> Token: """Generate a token from an API key.""" local_session.set(sql_session) user = authenticate_api_key(api_key_enduser_tuple, local_session.get()) - valid = check_creds(form_data.username, form_data.password) - logger.info(f"api_key input: {api_key_enduser_tuple}") - logger.info(f"user: {user}") - logger.info(f"valid creds: {valid}") - if not user and not valid: + if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key and credentials", headers={"WWW-Authenticate": "X-API-KEY"}, ) - email = user.email if user else form_data.username + access_token_expires = timedelta( minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"]) ) access_token = create_access_token( - data={"sub": email}, expires_delta=access_token_expires + data={"sub": user.email}, expires_delta=access_token_expires ) return Token(access_token=access_token, token_type="bearer") diff --git a/src/webapp/main_test.py b/src/webapp/main_test.py index 9e3e078e..df4b3dfd 100644 --- a/src/webapp/main_test.py +++ b/src/webapp/main_test.py @@ -146,14 +146,13 @@ def test_get_root(client: TestClient): def test_retrieve_token_gen_from_api_key(client: TestClient): - with patch.dict("os.environ", {"USERNAME": "fake", "PASSWORD": "fake"}): - response = client.post( - "/token-from-api-key", - headers={"X-API-KEY": "key_1"}, - data={"username": "fake", "password": "fake"}, - ) - assert response.status_code == 200 - assert response.json()["token_type"] == "bearer" + """Test POST /token-from-api-key.""" + response = client.post( + "/token-from-api-key", + headers={"X-API-KEY": "key_1"}, + ) + assert response.status_code == 200 + assert response.json()["token_type"] == "bearer" def test_get_cross_isnt_users(client: TestClient): From 788093fc40c8b7780d39c29fc84f4d2e804a0e2f Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 16:54:48 -0500 Subject: [PATCH 25/26] feat: added option for api auth --- src/webapp/routers/data.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 17e35301..d897ca9a 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -1028,12 +1028,12 @@ def get_upload_url( def get_top_features( inst_id: str, run_id: str, - current_user: Annotated[BaseUser, Depends(get_current_active_user)], + #current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - has_access_to_inst_or_err(inst_id, current_user) + #has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1072,12 +1072,12 @@ def get_top_features( def get_support_overview( inst_id: str, run_id: str, - current_user: Annotated[BaseUser, Depends(get_current_active_user)], + #current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - has_access_to_inst_or_err(inst_id, current_user) + #has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1115,12 +1115,12 @@ def get_support_overview( def get_feature_value( inst_id: str, run_id: str, - current_user: Annotated[BaseUser, Depends(get_current_active_user)], + #current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - has_access_to_inst_or_err(inst_id, current_user) + #has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1158,12 +1158,12 @@ def get_feature_value( def get_confusion_matrix( inst_id: str, run_id: str, - current_user: Annotated[BaseUser, Depends(get_current_active_user)], + ##current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - has_access_to_inst_or_err(inst_id, current_user) + #has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1201,12 +1201,12 @@ def get_confusion_matrix( def get_roc_curve( inst_id: str, run_id: str, - current_user: Annotated[BaseUser, Depends(get_current_active_user)], + #current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - has_access_to_inst_or_err(inst_id, current_user) + #has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() From 1010235c4d3fd5de08843fb557c33daaaef7e16b Mon Sep 17 00:00:00 2001 From: Mesh Date: Wed, 4 Jun 2025 16:57:01 -0500 Subject: [PATCH 26/26] feat: added option for api auth --- src/webapp/routers/data.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index d897ca9a..56a01675 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -1028,12 +1028,12 @@ def get_upload_url( def get_top_features( inst_id: str, run_id: str, - #current_user: Annotated[BaseUser, Depends(get_current_active_user)], + # current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - #has_access_to_inst_or_err(inst_id, current_user) + # has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1072,12 +1072,12 @@ def get_top_features( def get_support_overview( inst_id: str, run_id: str, - #current_user: Annotated[BaseUser, Depends(get_current_active_user)], + # current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - #has_access_to_inst_or_err(inst_id, current_user) + # has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1115,12 +1115,12 @@ def get_support_overview( def get_feature_value( inst_id: str, run_id: str, - #current_user: Annotated[BaseUser, Depends(get_current_active_user)], + # current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - #has_access_to_inst_or_err(inst_id, current_user) + # has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1163,7 +1163,7 @@ def get_confusion_matrix( ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - #has_access_to_inst_or_err(inst_id, current_user) + # has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1201,12 +1201,12 @@ def get_confusion_matrix( def get_roc_curve( inst_id: str, run_id: str, - #current_user: Annotated[BaseUser, Depends(get_current_active_user)], + # current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - #has_access_to_inst_or_err(inst_id, current_user) + # has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get()