|
5 | 5 | import jsonpickle |
6 | 6 | from fastapi import APIRouter, Depends, HTTPException, status |
7 | 7 | from pydantic import BaseModel |
8 | | -from sqlalchemy import and_ |
| 8 | +from sqlalchemy import and_, update, or_ |
9 | 9 | from sqlalchemy.orm import Session |
10 | 10 | from sqlalchemy.future import select |
11 | 11 | from ..databricks import DatabricksControl, DatabricksInferenceRunRequest |
@@ -568,7 +568,7 @@ def trigger_inference_run( |
568 | 568 | ) from e |
569 | 569 | triggered_timestamp = datetime.now() |
570 | 570 | latest_model_version = databricks_control.fetch_model_version( |
571 | | - catalog_name=env_vars["CATALOG_NAME"], |
| 571 | + catalog_name=str(env_vars["CATALOG_NAME"]), |
572 | 572 | inst_name=inst_result[0][0].name, |
573 | 573 | model_name=model_name, |
574 | 574 | ) |
@@ -628,9 +628,101 @@ def get_model_versions( |
628 | 628 | print(f"Converted model name {transformed_model_name}") |
629 | 629 |
|
630 | 630 | latest_model_version = databricks_control.fetch_model_version( |
631 | | - catalog_name=env_vars["CATALOG_NAME"], |
| 631 | + catalog_name=str(env_vars["CATALOG_NAME"]), |
632 | 632 | inst_name=f"{query_result[0][0].name}", |
633 | 633 | model_name=transformed_model_name, |
634 | 634 | ) |
635 | 635 |
|
636 | 636 | return latest_model_version |
| 637 | + |
| 638 | + |
@router.post("/{inst_id}/models/{model_name}/backfill-model-runs")
def backfill_model_runs(
    inst_id: str,
    model_name: str,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
    """Backfill missing model run metadata on this model's existing jobs.

    Temporary endpoint to populate ``model_run_id`` and ``model_version`` on
    existing jobs for this model. Use only when backfilling historical job
    runs, not for regular operation.

    A job is updated when *either* field is NULL or empty; in that case both
    fields are overwritten with the values of the model version returned by
    ``databricks_control.fetch_model_version``.

    Args:
        inst_id: Institution id from the URL path (converted with
            ``str_to_uuid``).
        model_name: URL-encoded model name from the path; it is decoded and
            stripped before use.
        current_user: Authenticated user, checked against ``inst_id``.
        sql_session: Database session for this request.
        databricks_control: Client used to fetch the model version.

    Returns:
        A dict with ``updated_count``, ``updated_rows`` (id, model_id,
        model_run_id and model_version of every job that was changed) and
        ``latest_model_version`` (the ``version`` / ``run_id`` applied).

    Raises:
        HTTPException: 404 if the institution or the model does not exist,
            500 if more than one institution matches ``inst_id``.
    """
    transformed_model_name = str(decode_url_piece(model_name)).strip()
    has_access_to_inst_or_err(inst_id, current_user)

    local_session.set(sql_session)
    session = local_session.get()
    inst_uuid = str_to_uuid(inst_id)

    # Load and validate the institution *before* touching its attributes, so
    # an unknown id yields a 404 instead of an IndexError.
    inst_rows = session.execute(
        select(InstTable).where(InstTable.id == inst_uuid)
    ).all()
    if not inst_rows:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(inst_rows) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )
    inst_name = f"{inst_rows[0][0].name}"

    # Look the model up by its own (decoded) name within the institution.
    # NOTE(review): this previously filtered ModelTable.name by the
    # institution's name, which could only match a model named exactly like
    # its institution -- confirm the decoded path name is the stored name.
    model_rows = session.execute(
        select(ModelTable).where(
            and_(
                ModelTable.inst_id == inst_uuid,
                ModelTable.name == transformed_model_name,
            )
        )
    ).all()
    if not model_rows:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found.",
        )
    model_row_id = model_rows[0][0].id

    # Only call out to Databricks once the local lookups have succeeded.
    latest_mv = databricks_control.fetch_model_version(
        catalog_name=str(env_vars["CATALOG_NAME"]),
        inst_name=inst_name,
        model_name=transformed_model_name,
    )
    mv_version = str(latest_mv.version)
    mv_run_id = str(latest_mv.run_id)

    # UPDATE existing jobs for this model (only those missing values).
    stmt = (
        update(JobTable)
        .where(JobTable.model_id == model_row_id)
        .where(
            or_(
                JobTable.model_run_id.is_(None),
                JobTable.model_run_id == "",
                JobTable.model_version.is_(None),
                JobTable.model_version == "",
            )
        )
        .values(model_run_id=mv_run_id, model_version=mv_version)
        .returning(
            JobTable.id,
            JobTable.model_id,
            JobTable.model_run_id,
            JobTable.model_version,
        )
    )

    result = session.execute(stmt)
    # Materialise the RETURNING rows before committing.
    updated_rows = [dict(row._mapping) for row in result.fetchall()]
    session.commit()

    return {
        "updated_count": len(updated_rows),
        "updated_rows": updated_rows,
        "latest_model_version": {
            "version": mv_version,
            "run_id": mv_run_id,
        },
    }
0 commit comments