Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
53ebd92
feat: added all FE tables
Mesh-ach Jun 3, 2025
1c160fb
feat: added all FE tables
Mesh-ach Jun 3, 2025
549f49f
feat: added all FE tables
Mesh-ach Jun 3, 2025
d69ed25
feat: added all FE tables
Mesh-ach Jun 3, 2025
63e7ce2
feat: added all FE tables
Mesh-ach Jun 3, 2025
98f3f97
feat: added all FE tables
Mesh-ach Jun 3, 2025
d28b48c
feat: added all FE tables
Mesh-ach Jun 3, 2025
f9ff963
feat: added all FE tables
Mesh-ach Jun 3, 2025
174bc13
feat: added all FE tables
Mesh-ach Jun 3, 2025
0ce899d
feat: added all FE tables
Mesh-ach Jun 3, 2025
5e87ef8
Merge pull request #92 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
160bf23
feat: added option for api auth
Mesh-ach Jun 4, 2025
4c457f3
feat: added option for api auth
Mesh-ach Jun 4, 2025
ce89009
feat: added option for api auth
Mesh-ach Jun 4, 2025
6bb6ea3
feat: added option for api auth
Mesh-ach Jun 4, 2025
33006b6
feat: added option for api auth
Mesh-ach Jun 4, 2025
19aba9d
feat: added option for api auth
Mesh-ach Jun 4, 2025
aba20d2
Merge pull request #93 from datakind/Validation-Errors
Mesh-ach Jun 4, 2025
6c6efc6
feat: added option for api auth
Mesh-ach Jun 4, 2025
550010a
feat: added option for api auth
Mesh-ach Jun 4, 2025
f3edcc5
feat: added option for api auth
Mesh-ach Jun 4, 2025
9831998
Merge pull request #94 from datakind/Validation-Errors
Mesh-ach Jun 4, 2025
3b57641
feat: added option for api auth
Mesh-ach Jun 4, 2025
9ed2271
feat: added option for api auth
Mesh-ach Jun 4, 2025
4bdaea0
feat: added option for api auth
Mesh-ach Jun 4, 2025
ea08663
feat: added option for api auth
Mesh-ach Jun 4, 2025
b69c428
feat: added option for api auth
Mesh-ach Jun 4, 2025
788093f
feat: added option for api auth
Mesh-ach Jun 4, 2025
1010235
feat: added option for api auth
Mesh-ach Jun 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/webapp/authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def get_api_key(
)


def check_creds(username: str, password: str) -> bool:
    """Check basic-auth credentials against the configured USERNAME/PASSWORD.

    Uses ``hmac.compare_digest`` so the comparison runs in constant time,
    preventing a timing side-channel from leaking how many leading characters
    of the username or password were correct. Both comparisons are always
    evaluated (no short-circuit on a bad username) for the same reason.
    """
    # Local import keeps this function self-contained; hmac is stdlib.
    import hmac

    user_ok = hmac.compare_digest(
        username.encode("utf-8"), env_vars["USERNAME"].encode("utf-8")
    )
    pass_ok = hmac.compare_digest(
        password.encode("utf-8"), env_vars["PASSWORD"].encode("utf-8")
    )
    return user_ok and pass_ok


def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel
Generates hashes that start with 2y. The hashing scheme recognizes both."""
Expand Down
4 changes: 4 additions & 0 deletions src/webapp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
"API_KEY_ISSUERS": [],
"INITIAL_API_KEY": "",
"INITIAL_API_KEY_ID": "",
"CATALOG_NAME": "",
"SQL_WAREHOUSE_ID": "",
"USERNAME": "",
"PASSWORD": "",
}

# The INSTANCE_HOST is the private IP of CLoudSQL instance e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex)
Expand Down
66 changes: 65 additions & 1 deletion src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from pydantic import BaseModel
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog

from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout
from .config import databricks_vars, gcs_vars
from .utilities import databricksify_inst_name, SchemaType
from typing import List, Any
import time

# List of data medallion levels
MEDALLION_LEVELS = ["silver", "gold", "bronze"]
Expand Down Expand Up @@ -191,3 +193,65 @@ def delete_inst(self, inst_name: str) -> None:
full_name=f"{cat_name}.{db_inst_name}_{medallion}.{table}"
)
w.schemas.delete(full_name=f"{cat_name}.{db_inst_name}_{medallion}")

def fetch_table_data(
    self,
    catalog_name: Any,
    schema_name: Any,
    table_name: Any,
    warehouse_id: Any,
    limit: int = 1000,
) -> List[dict[str, Any]]:
    """Run ``SELECT * FROM <catalog>.<schema>.<table> LIMIT <limit>`` on the
    given SQL warehouse and return rows as ``{column_name: value}`` dicts.

    The statement is executed with a 10s synchronous wait. If it has not
    finished by then, the statement is polled by id until it reaches a
    terminal state or a hard deadline expires.

    Raises:
        ValueError: client initialization failure, unexpected response
            shape, query failure/cancellation, or polling timeout.
    """
    poll_interval_s = 1.0
    max_poll_s = 300.0  # hard cap so a stuck statement cannot hang the request forever

    def _state_name(response: Any) -> str:
        # Normalize the SDK state to a plain string: recent SDK versions
        # return a StatementState enum whose .value is e.g. "SUCCEEDED";
        # comparing the enum itself to a bare string would never match.
        state = response.status.state if getattr(response, "status", None) else None
        return getattr(state, "value", state) or ""

    def _column_names(response: Any) -> List[str]:
        # NOTE(review): depending on SDK version, columns live on
        # manifest.schema.columns rather than on the schema object itself —
        # accept either shape. TODO confirm against the pinned SDK version.
        schema = response.manifest.schema
        cols = getattr(schema, "columns", None) or schema
        return [col.name for col in cols]

    w = WorkspaceClient(
        host=databricks_vars["DATABRICKS_HOST_URL"],
        google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
    )
    if not w:
        raise ValueError("fetch_table_data(): could not initialize WorkspaceClient.")

    # Backtick-quote each identifier part; limit is an int by signature.
    fq_table = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
    sql = f"SELECT * FROM {fq_table} LIMIT {limit}"

    resp = w.statement_execution.execute_statement(
        warehouse_id=warehouse_id,
        statement=sql,
        format=Format.JSON_ARRAY,
        wait_timeout="10s",
        on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE,
    )

    if not (_state_name(resp) == "SUCCEEDED" and getattr(resp, "result", None)):
        # The statement did not finish within the synchronous wait window;
        # poll it by id until it reaches a terminal state.
        stmt_id = getattr(resp, "statement_id", None)
        if not stmt_id:
            raise ValueError(f"fetch_table_data(): unexpected response state: {resp}")

        deadline = time.monotonic() + max_poll_s
        state = _state_name(resp)
        while state not in ("SUCCEEDED", "FAILED", "CANCELED"):
            if time.monotonic() > deadline:
                raise ValueError(
                    f"fetch_table_data(): timed out polling statement {stmt_id}"
                )
            time.sleep(poll_interval_s)
            resp = w.statement_execution.get_statement(statement_id=stmt_id)
            state = _state_name(resp)
        if state != "SUCCEEDED":
            raise ValueError(f"fetch_table_data(): query ended with state {state}")

    column_names = _column_names(resp)
    # data_array is None for an empty result set — treat that as no rows
    # instead of crashing in zip().
    result = getattr(resp, "result", None)
    rows = (result.data_array if result else None) or []
    return [dict(zip(column_names, row)) for row in rows]
7 changes: 6 additions & 1 deletion src/webapp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import secrets
from fastapi import FastAPI, Depends, HTTPException, status, Security
from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from sqlalchemy.future import select
from sqlalchemy import update
Expand Down Expand Up @@ -37,6 +38,7 @@
create_access_token,
get_api_key,
get_api_key_hash,
check_creds,
)

# Set the logging
Expand Down Expand Up @@ -99,13 +101,16 @@ async def access_token_from_api_key(
) -> Token:
"""Generate a token from an API key."""
local_session.set(sql_session)

user = authenticate_api_key(api_key_enduser_tuple, local_session.get())

if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API key not valid",
detail="Invalid API key and credentials",
headers={"WWW-Authenticate": "X-API-KEY"},
)

access_token_expires = timedelta(
minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"])
)
Expand Down
1 change: 1 addition & 0 deletions src/webapp/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_session,
ApiKeyTable,
)
from unittest.mock import patch
from .authn import get_password_hash, get_api_key_hash
from .test_helper import (
DATAKINDER,
Expand Down
220 changes: 220 additions & 0 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import logging
from sqlalchemy.exc import IntegrityError
from ..config import env_vars

from ..utilities import (
has_access_to_inst_or_err,
Expand All @@ -31,8 +32,10 @@
local_session,
BatchTable,
FileTable,
InstTable,
)

from ..databricks import DatabricksControl
from ..gcsdbutils import update_db_from_bucket

from ..gcsutil import StorageControl
Expand Down Expand Up @@ -1018,3 +1021,220 @@ def get_upload_url(
except ValueError as ve:
# Return a 400 error with the specific message from ValueError
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


# Get SHAP values (features with most impact) for an inference run
# response_model=None: the endpoint returns List[dict], not str — the previous
# response_model=str would make FastAPI reject every successful response.
@router.get("/{inst_id}/inference/top-features/{run_id}", response_model=None)
def get_top_features(
    inst_id: str,
    run_id: str,
    # NOTE(review): auth is disabled — any caller can read any institution's
    # data. Re-enable before production:
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return the features-with-most-impact (SHAP) rows for an inference run.

    Looks up the institution by id, then reads up to 500 rows from the
    Databricks table
    ``<CATALOG>.<inst>_silver.sample_inference_<run_id>_features_with_most_impact``.

    Raises:
        HTTPException 404: institution not found.
        HTTPException 500: duplicate institutions share this id.
        HTTPException 400: the Databricks query failed (ValueError below).
    """
    local_session.set(sql_session)
    # .all() yields Row tuples; the InstTable entity is element 0 of each row.
    matches = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not matches:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(matches) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{matches[0][0].name}_silver",
            table_name=f"sample_inference_{run_id}_features_with_most_impact",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Surface Databricks-layer failures as a 400 with the original message.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


# Get the support-overview table for an inference run
# response_model=None: the endpoint returns List[dict], not str — the previous
# response_model=str would make FastAPI reject every successful response.
@router.get("/{inst_id}/inference/support-overview/{run_id}", response_model=None)
def get_support_overview(
    inst_id: str,
    run_id: str,
    # NOTE(review): auth is disabled — any caller can read any institution's
    # data. Re-enable before production:
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return the support-overview rows for an inference run.

    Looks up the institution by id, then reads up to 500 rows from the
    Databricks table
    ``<CATALOG>.<inst>_silver.sample_inference_<run_id>_support_overview``.

    Raises:
        HTTPException 404: institution not found.
        HTTPException 500: duplicate institutions share this id.
        HTTPException 400: the Databricks query failed (ValueError below).
    """
    local_session.set(sql_session)
    # .all() yields Row tuples; the InstTable entity is element 0 of each row.
    matches = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not matches:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(matches) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{matches[0][0].name}_silver",
            table_name=f"sample_inference_{run_id}_support_overview",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Surface Databricks-layer failures as a 400 with the original message.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


# response_model=None: the endpoint returns List[dict], not str — the previous
# response_model=str would make FastAPI reject every successful response.
@router.get("/{inst_id}/inference/feature_value/{run_id}", response_model=None)
def get_feature_value(
    inst_id: str,
    run_id: str,
    # NOTE(review): auth is disabled — any caller can read any institution's
    # data. Re-enable before production:
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return SHAP feature-importance rows for an inference run.

    Looks up the institution by id, then reads up to 500 rows from the
    Databricks table
    ``<CATALOG>.<inst>_silver.sample_inference_<run_id>_shap_feature_importance``.

    Raises:
        HTTPException 404: institution not found.
        HTTPException 500: duplicate institutions share this id.
        HTTPException 400: the Databricks query failed (ValueError below).
    """
    local_session.set(sql_session)
    # .all() yields Row tuples; the InstTable entity is element 0 of each row.
    matches = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not matches:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(matches) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{matches[0][0].name}_silver",
            table_name=f"sample_inference_{run_id}_shap_feature_importance",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Surface Databricks-layer failures as a 400 with the original message.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


# response_model=None: the endpoint returns List[dict], not str — the previous
# response_model=str would make FastAPI reject every successful response.
@router.get("/{inst_id}/training/confusion_matrix/{run_id}", response_model=None)
def get_confusion_matrix(
    inst_id: str,
    run_id: str,
    # NOTE(review): auth is disabled — any caller can read any institution's
    # data. Re-enable before production:
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return confusion-matrix rows for a training run.

    Looks up the institution by id, then reads up to 500 rows from the
    Databricks table
    ``<CATALOG>.<inst>_silver.sample_training_<run_id>_confusion_matrix``.

    Raises:
        HTTPException 404: institution not found.
        HTTPException 500: duplicate institutions share this id.
        HTTPException 400: the Databricks query failed (ValueError below).
    """
    local_session.set(sql_session)
    # .all() yields Row tuples; the InstTable entity is element 0 of each row.
    matches = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not matches:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(matches) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{matches[0][0].name}_silver",
            table_name=f"sample_training_{run_id}_confusion_matrix",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Surface Databricks-layer failures as a 400 with the original message.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


# response_model=None: the endpoint returns List[dict], not str — the previous
# response_model=str would make FastAPI reject every successful response.
@router.get("/{inst_id}/training/roc_curve/{run_id}", response_model=None)
def get_roc_curve(
    inst_id: str,
    run_id: str,
    # NOTE(review): auth is disabled — any caller can read any institution's
    # data. Re-enable before production:
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return ROC-curve rows for a training run.

    Looks up the institution by id, then reads up to 500 rows from the
    Databricks table
    ``<CATALOG>.<inst>_silver.sample_training_<run_id>_roc_curve``.

    Raises:
        HTTPException 404: institution not found.
        HTTPException 500: duplicate institutions share this id.
        HTTPException 400: the Databricks query failed (ValueError below).
    """
    local_session.set(sql_session)
    # .all() yields Row tuples; the InstTable entity is element 0 of each row.
    matches = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not matches:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(matches) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{matches[0][0].name}_silver",
            table_name=f"sample_training_{run_id}_roc_curve",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Surface Databricks-layer failures as a 400 with the original message.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
Loading
Loading