1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "pandera~=0.13",
     "mlflow~=2.15.0",
     "cachetools",
+    "types-cachetools",
 ]
 
 [project.urls]
45 changes: 35 additions & 10 deletions src/webapp/databricks.py
@@ -90,8 +90,8 @@ def _sha256_json(obj: Any) -> str:
 
 L1_RESP_CACHE_TTL = int("600")  # seconds
 L1_VER_CACHE_TTL = int("3600")  # seconds
-L1_RESP_CACHE = TTLCache(maxsize=128, ttl=L1_RESP_CACHE_TTL)
-L1_VER_CACHE = TTLCache(maxsize=256, ttl=L1_VER_CACHE_TTL)
+L1_RESP_CACHE: Any = TTLCache(maxsize=128, ttl=L1_RESP_CACHE_TTL)
+L1_VER_CACHE: Any = TTLCache(maxsize=256, ttl=L1_VER_CACHE_TTL)
 _L1_LOCK = threading.RLock()

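Side note on the `Any` annotations above: the `types-cachetools` stub package added in pyproject.toml declares `TTLCache` as generic, so these caches could later be given precise types instead of `Any`. A minimal sketch, assuming deferred annotation evaluation; the key/value types here are illustrative guesses, not taken from this PR:

```python
from __future__ import annotations  # keeps the subscripted annotations lazy at runtime

from typing import Any

from cachetools import TTLCache

# Hypothetical tightened annotations for the two module-level caches.
L1_RESP_CACHE: TTLCache[str, list[dict[str, Any]]] = TTLCache(maxsize=128, ttl=600)
L1_VER_CACHE: TTLCache[str, str] = TTLCache(maxsize=256, ttl=3600)
```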
@@ -251,7 +251,6 @@ def run_pdp_inference(
             ],  # is this value the same PER environ? dev/staging/prod
             "gcp_bucket_name": req.gcp_external_bucket_name,
             "model_name": req.model_name,
-            "model_type": req.model_type,
             "notification_email": req.email,
         },
     )
@@ -333,7 +332,7 @@ def fetch_table_data(
         inst_name: str,
         table_name: str,
         warehouse_id: str,
-    ) -> List[Dict[str, Any]]:
+    ) -> Any:
         """
         Execute SELECT * via Databricks SQL Statement Execution API using EXTERNAL_LINKS.
         Blocks server-side for up to 30s; if not SUCCEEDED, raises. Downloads presigned
@@ -366,9 +365,9 @@ def fetch_table_data(
 
         if not ver_resp.status or ver_resp.status.state != StatementState.SUCCEEDED:
             raise TimeoutError("DESCRIBE HISTORY did not finish within 30s")
-        cols = [c.name for c in ver_resp.manifest.schema.columns]
+        cols = [c.name for c in ver_resp.manifest.schema.columns]  # type: ignore
         idx = {n: i for i, n in enumerate(cols)}
-        rows = ver_resp.result.data_array or []
+        rows = ver_resp.result.data_array or []  # type: ignore
         if not rows or "version" not in idx:
             raise ValueError("DESCRIBE HISTORY returned no version")
         table_version = str(rows[0][idx["version"]])
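The `# type: ignore` comments above work around mypy flagging `manifest` and `result` as Optional on the SDK response. An alternative, not part of this PR, is to narrow the Optionals explicitly so the checker stays active; a sketch against the same `ver_resp` object:

```python
# Narrow the Optional fields up front instead of suppressing the checker.
if ver_resp.manifest is None or ver_resp.manifest.schema is None:
    raise ValueError("DESCRIBE HISTORY response is missing a manifest/schema")
if ver_resp.result is None:
    raise ValueError("DESCRIBE HISTORY response is missing a result")

cols = [c.name for c in ver_resp.manifest.schema.columns or []]
idx = {n: i for i, n in enumerate(cols)}
rows = ver_resp.result.data_array or []
```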
@@ -432,13 +431,13 @@ def fetch_table_data(
             resp.manifest and resp.manifest.schema and resp.manifest.schema.columns
         ):
             raise ValueError("Schema/columns missing (EXTERNAL_LINKS).")
-        cols: List[str] = []
+        cols: List[str] = []  # type: ignore
         for c in resp.manifest.schema.columns:
             if c.name is None:
                 raise ValueError("Encountered a column without a name.")
             cols.append(c.name)
 
-        records: List[Dict[str, Any]] = []
+        records: Any = []
 
         # Helper: consume one chunk-like object (first result or subsequent chunk)
         def _consume_chunk(chunk_obj: Any) -> int | None:
@@ -504,7 +503,9 @@ def _consume_chunk(chunk_obj: Any) -> int | None:
            pass
         return records
 
-    def fetch_model_version(self, catalog_name: str, inst_name: str, model_name: str):
+    def fetch_model_version(
+        self, catalog_name: str, inst_name: str, model_name: str
+    ) -> Any:
         schema = databricksify_inst_name(inst_name)
         model_name_path = f"{catalog_name}.{schema}_gold.{model_name}"
 
@@ -521,7 +522,7 @@ def fetch_model_version(self, catalog_name: str, inst_name: str, model_name: str
             )
             raise ValueError(f"setup_new_inst(): Workspace client creation failed: {e}")
 
-        model_versions = list(
+        model_versions: Any = list(
             w.model_versions.list(
                 full_name=model_name_path,
             )
@@ -534,6 +535,30 @@ def fetch_model_version(self, catalog_name: str, inst_name: str, model_name: str
 
         return latest_version
 
+    def delete_model(self, catalog_name: str, inst_name: str, model_name: str) -> None:
+        schema = databricksify_inst_name(inst_name)
+        model_name_path = f"{catalog_name}.{schema}_gold.{model_name}"
+
+        try:
+            w = WorkspaceClient(
+                host=databricks_vars["DATABRICKS_HOST_URL"],
+                google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
+            )
+        except Exception as e:
+            LOGGER.exception(
+                "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
+                databricks_vars["DATABRICKS_HOST_URL"],
+                gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
+            )
+            raise ValueError(f"delete_model(): Workspace client creation failed: {e}")
+
+        try:
+            w.registered_models.delete(full_name=model_name_path)
+            LOGGER.info("Deleted registered model: %s", model_name_path)
+        except Exception:
+            LOGGER.exception("Failed to delete registered model: %s", model_name_path)
+            raise
+
     def get_key_for_file(
         self, mapping: Dict[str, Any], file_name: str
     ) -> Optional[str]:
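The new `delete_model` above wraps the Databricks SDK's Unity Catalog registered-models delete call. A hedged usage sketch; the catalog, institution, and model names are made up:

```python
# Hypothetical call site; DatabricksControl lives in src/webapp/databricks.py.
control = DatabricksControl()
control.delete_model(
    catalog_name="main",  # illustrative catalog
    inst_name="Example University",  # databricksify_inst_name() derives the schema prefix
    model_name="retention_model",  # resolves to main.<schema>_gold.retention_model
)
```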
50 changes: 49 additions & 1 deletion src/webapp/routers/models.py
@@ -311,6 +311,54 @@ def read_inst_model(
     }
 
 
+@router.delete("/{inst_id}/models/{model_name}")
+def delete_model(
+    inst_id: str,
+    model_name: str,
+    delete_from_databricks: bool,
+    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
+    sql_session: Annotated[Session, Depends(get_session)],
+    databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
+) -> Any:
+    transformed_model_name = str(decode_url_piece(model_name)).strip()
+    has_access_to_inst_or_err(inst_id, current_user)
+    model_owner_and_higher_or_err(current_user, "modify batch")
+
+    local_session.set(sql_session)
+    sess = local_session.get()
+
+    query_result = sess.execute(
+        select(InstTable).where(InstTable.id == str_to_uuid(inst_id))
+    ).all()
+
+    model_list = sess.execute(
+        select(ModelTable).where(
+            ModelTable.name == str_to_uuid(model_name),
+            ModelTable.inst_id == str_to_uuid(inst_id),
+        )
+    ).scalar_one_or_none()
+    if model_list is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail="Model not found."
+        )
+
+    if delete_from_databricks:
+        # Optionally delete the model from Databricks itself
+        databricks_control.delete_model(
+            catalog_name=str(env_vars["CATALOG_NAME"]),
+            inst_name=f"{query_result[0][0].name}",
+            model_name=transformed_model_name,
+        )
+
+    sess.delete(model_list)
+    sess.commit()
+    return {
+        "inst_id": inst_id,
+        "model_name": transformed_model_name,
+        "deleted_from_databricks": delete_from_databricks,
+    }
+
+
 @router.get("/{inst_id}/models/{model_name}/runs", response_model=list[RunInfo])
 def read_inst_model_outputs(
     inst_id: str,
@@ -710,7 +758,7 @@ def backfill_model_runs(
         .values(model_run_id=mv_run_id, model_version=mv_version)
     )
     result = local_session.get().execute(stmt)
-    updated_count = result.rowcount or 0
+    updated_count = result.rowcount or 0  # type: ignore
     local_session.get().commit()
 
     return {
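End-to-end, the new route removes the `ModelTable` row and, when requested, the registered model in Databricks. A sketch of calling it, assuming the router is mounted under `/institutions` and noting that a bare `bool` parameter like `delete_from_databricks` becomes a required query parameter in FastAPI:

```python
import httpx

# Hypothetical host, token, and IDs; the route shape comes from the diff above.
resp = httpx.delete(
    "https://webapp.example.com/institutions/<inst_id>/models/retention_model",
    params={"delete_from_databricks": "true"},  # required query parameter
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
print(resp.json())  # {"inst_id": ..., "model_name": ..., "deleted_from_databricks": true}
```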