Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 92 additions & 3 deletions src/webapp/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jsonpickle
from fastapi import APIRouter, Depends, HTTPException, status
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow this Depends logic is cool

from pydantic import BaseModel
from sqlalchemy import and_
from sqlalchemy import and_, update, or_
from sqlalchemy.orm import Session
from sqlalchemy.future import select
from ..databricks import DatabricksControl, DatabricksInferenceRunRequest
Expand Down Expand Up @@ -568,7 +568,7 @@ def trigger_inference_run(
) from e
triggered_timestamp = datetime.now()
latest_model_version = databricks_control.fetch_model_version(
catalog_name=env_vars["CATALOG_NAME"],
catalog_name=str(env_vars["CATALOG_NAME"]),
inst_name=inst_result[0][0].name,
model_name=model_name,
)
Expand Down Expand Up @@ -628,9 +628,98 @@ def get_model_versions(
print(f"Converted model name {transformed_model_name}")

latest_model_version = databricks_control.fetch_model_version(
catalog_name=env_vars["CATALOG_NAME"],
catalog_name=str(env_vars["CATALOG_NAME"]),
inst_name=f"{query_result[0][0].name}",
model_name=transformed_model_name,
)

return latest_model_version


@router.post("/{inst_id}/models/{model_name}/backfill-model-runs")
def backfill_model_runs(
    inst_id: str,
    model_name: str,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
    """Backfill missing model run IDs / versions on a model's existing jobs.

    WARNING: this stamps the LATEST model version and run ID reported by
    Databricks onto every job of the model that is missing either value. It
    does not try to recover the version that was actually used when each job
    ran, so use this endpoint carefully.

    NOTE(review): reviewers suggested this endpoint may be deleted later.

    Only jobs whose ``model_run_id`` or ``model_version`` is NULL or an empty
    string are updated; jobs that already carry both values are never
    overwritten.

    Args:
        inst_id: Institution ID from the URL path (converted with
            ``str_to_uuid``).
        model_name: URL-encoded model name from the URL path.
        current_user: Authenticated user, checked with
            ``has_access_to_inst_or_err``.
        sql_session: Database session used for the lookups and the update.
        databricks_control: Client used to fetch the latest model version.

    Returns:
        A dict with ``updated_count``, ``updated_rows`` (id, model_id,
        model_run_id and model_version of each updated job) and
        ``latest_model_version`` (the ``version`` and ``run_id`` applied).

    Raises:
        HTTPException: 404 if the institution or the model is not found;
            500 if more than one institution matches ``inst_id``.
    """
    transformed_model_name = str(decode_url_piece(model_name)).strip()
    # Presumably raises when the user has no access to the institution.
    has_access_to_inst_or_err(inst_id, current_user)

    local_session.set(sql_session)
    session = local_session.get()
    inst_uuid = str_to_uuid(inst_id)

    # Load institution and validate the result BEFORE indexing into it, so a
    # missing institution yields a 404 instead of an IndexError (HTTP 500).
    inst_rows = session.execute(
        select(InstTable).where(InstTable.id == inst_uuid)
    ).all()
    if not inst_rows:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(inst_rows) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )
    inst_name = str(inst_rows[0][0].name)

    # Load the model by its own (decoded) name. The previous code filtered
    # ModelTable.name by the institution name, which does not match the model
    # requested in the URL or the one fetched from Databricks below.
    model_rows = session.execute(
        select(ModelTable).where(
            and_(
                ModelTable.inst_id == inst_uuid,
                ModelTable.name == transformed_model_name,
            )
        )
    ).all()
    if not model_rows:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found.",
        )
    model_row_id = model_rows[0][0].id

    # Get latest model version from Databricks
    latest_mv = databricks_control.fetch_model_version(
        catalog_name=str(env_vars["CATALOG_NAME"]),
        inst_name=inst_name,
        model_name=transformed_model_name,
    )

    # Coerce types as needed
    mv_version = str(latest_mv.version)
    mv_run_id = str(latest_mv.run_id)

    # UPDATE existing jobs for this model, restricted to rows that are missing
    # a run ID or version. Rows that already have both values are left alone,
    # which keeps repeated calls from clobbering previously recorded data.
    stmt = (
        update(JobTable)
        .where(JobTable.model_id == model_row_id)
        .where(
            or_(
                JobTable.model_run_id.is_(None),
                JobTable.model_run_id == "",
                JobTable.model_version.is_(None),
                JobTable.model_version == "",
            )
        )
        .values(model_run_id=mv_run_id, model_version=mv_version)
        .returning(
            JobTable.id,
            JobTable.model_id,
            JobTable.model_run_id,
            JobTable.model_version,
        )
    )

    result = session.execute(stmt)
    # Materialize the RETURNING rows before committing.
    updated_rows = [dict(row._mapping) for row in result.fetchall()]
    session.commit()

    return {
        "updated_count": len(updated_rows),
        "updated_rows": updated_rows,
        "latest_model_version": {
            "version": mv_version,
            "run_id": mv_run_id,
        },
    }
Loading