Skip to content

Commit 03f0275

Browse files
authored
Merge pull request #178 from datakind/BackfillEndpoint
Feat: Added backfill endpoint
2 parents 1969f73 + a176d62 commit 03f0275

File tree

1 file changed

+95
-3
lines changed

1 file changed

+95
-3
lines changed

src/webapp/routers/models.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import jsonpickle
66
from fastapi import APIRouter, Depends, HTTPException, status
77
from pydantic import BaseModel
8-
from sqlalchemy import and_
8+
from sqlalchemy import and_, update, or_
99
from sqlalchemy.orm import Session
1010
from sqlalchemy.future import select
1111
from ..databricks import DatabricksControl, DatabricksInferenceRunRequest
@@ -568,7 +568,7 @@ def trigger_inference_run(
568568
) from e
569569
triggered_timestamp = datetime.now()
570570
latest_model_version = databricks_control.fetch_model_version(
571-
catalog_name=env_vars["CATALOG_NAME"],
571+
catalog_name=str(env_vars["CATALOG_NAME"]),
572572
inst_name=inst_result[0][0].name,
573573
model_name=model_name,
574574
)
@@ -628,9 +628,101 @@ def get_model_versions(
628628
print(f"Converted model name {transformed_model_name}")
629629

630630
latest_model_version = databricks_control.fetch_model_version(
631-
catalog_name=env_vars["CATALOG_NAME"],
631+
catalog_name=str(env_vars["CATALOG_NAME"]),
632632
inst_name=f"{query_result[0][0].name}",
633633
model_name=transformed_model_name,
634634
)
635635

636636
return latest_model_version
637+
638+
639+
@router.post("/{inst_id}/models/{model_name}/backfill-model-runs")
640+
def backfill_model_runs(
641+
inst_id: str,
642+
model_name: str,
643+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
644+
sql_session: Annotated[Session, Depends(get_session)],
645+
databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
646+
) -> Any:
647+
"""Backfills missing model run metadata and returns the latest model version info.
648+
649+
Temporary endpoint to populate model_run_id and model_version on existing jobs for this model.
650+
Use only when backfilling historical job runs, not for regular operation.
651+
"""
652+
transformed_model_name = str(decode_url_piece(model_name)).strip()
653+
has_access_to_inst_or_err(inst_id, current_user)
654+
655+
# Load institution
656+
local_session.set(sql_session)
657+
inst_row = (
658+
local_session.get()
659+
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
660+
.all()
661+
)
662+
663+
model_id = (
664+
local_session.get()
665+
.execute(
666+
select(ModelTable).where(
667+
and_(
668+
ModelTable.inst_id == str_to_uuid(inst_id),
669+
ModelTable.name == f"{inst_row[0][0].name}",
670+
)
671+
)
672+
)
673+
.all()
674+
)
675+
676+
if not inst_row or len(inst_row) == 0:
677+
raise HTTPException(
678+
status_code=status.HTTP_404_NOT_FOUND,
679+
detail="Institution not found.",
680+
)
681+
if len(inst_row) > 1:
682+
raise HTTPException(
683+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
684+
detail="Institution duplicates found.",
685+
)
686+
687+
latest_mv = databricks_control.fetch_model_version(
688+
catalog_name=str(env_vars["CATALOG_NAME"]),
689+
inst_name=f"{inst_row[0][0].name}",
690+
model_name=transformed_model_name,
691+
)
692+
693+
mv_version = str(latest_mv.version)
694+
mv_run_id = str(latest_mv.run_id)
695+
696+
# UPDATE existing jobs for this model (only those missing values)
697+
stmt = (
698+
update(JobTable)
699+
.where(JobTable.model_id == model_id[0][0].id)
700+
.where(
701+
or_(
702+
JobTable.model_run_id.is_(None),
703+
JobTable.model_run_id == "",
704+
JobTable.model_version.is_(None),
705+
JobTable.model_version == "",
706+
)
707+
)
708+
.values(model_run_id=mv_run_id, model_version=mv_version)
709+
.returning(
710+
JobTable.id,
711+
JobTable.model_id,
712+
JobTable.model_run_id,
713+
JobTable.model_version,
714+
)
715+
)
716+
717+
result = local_session.get().execute(stmt)
718+
updated_rows = [dict(r._mapping) for r in result.fetchall()]
719+
local_session.get().commit()
720+
721+
return {
722+
"updated_count": len(updated_rows),
723+
"updated_rows": updated_rows,
724+
"latest_model_version": {
725+
"version": mv_version,
726+
"run_id": mv_run_id,
727+
},
728+
}

0 commit comments

Comments
 (0)