Skip to content

Commit 131eb21

Browse files
authored
Merge pull request #186 from datakind/AddCustomJob
feat: developed function for adding custom jobs with institution and model validation
2 parents 7d6b3fa + c453783 commit 131eb21

File tree

1 file changed

+84
-2
lines changed

1 file changed

+84
-2
lines changed

src/webapp/routers/data.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import uuid
44
from datetime import datetime, date
5-
from typing import Annotated, Any, Dict, List, cast, IO, Optional, Tuple
5+
from typing import Annotated, Any, Dict, List, Optional, Tuple
66
from pydantic import BaseModel, Field
7-
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
7+
from fastapi import APIRouter, Depends, HTTPException, status, Response
88
from sqlalchemy import and_, or_
99
from sqlalchemy.orm import Session
1010
from sqlalchemy.future import select
@@ -33,6 +33,8 @@
3333
BatchTable,
3434
FileTable,
3535
InstTable,
36+
JobTable,
37+
ModelTable,
3638
SchemaRegistryTable,
3739
DocType,
3840
)
@@ -41,6 +43,7 @@
4143
from ..gcsdbutils import update_db_from_bucket
4244

4345
from ..gcsutil import StorageControl
46+
from ..config import env_vars
4447

4548
# Set the logging
4649
logging.basicConfig(format="%(asctime)s [%(levelname)s]: %(message)s")
@@ -1345,3 +1348,82 @@ def get_upload_url(
13451348
except ValueError as ve:
13461349
# Return a 400 error with the specific message from ValueError
13471350
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
1351+
1352+
1353+
@router.get("/{inst_id}/add-custom-school-job/{job_run_id}")
1354+
def add_custom_school_job(
1355+
inst_id: str,
1356+
job_run_id: str,
1357+
model_name: str,
1358+
sql_session: Annotated[Session, Depends(get_session)],
1359+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1360+
databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
1361+
) -> Any:
1362+
"""Fill in a JobTable ."""
1363+
has_access_to_inst_or_err(inst_id, current_user)
1364+
has_full_data_access_or_err(current_user, "this model")
1365+
local_session.set(sql_session)
1366+
1367+
model_name = decode_url_piece(model_name)
1368+
inst_result = (
1369+
local_session.get()
1370+
.execute(
1371+
select(InstTable).where(
1372+
and_(
1373+
InstTable.id == str_to_uuid(inst_id),
1374+
)
1375+
)
1376+
)
1377+
.all()
1378+
)
1379+
1380+
query_result = (
1381+
local_session.get()
1382+
.execute(
1383+
select(ModelTable).where(
1384+
and_(
1385+
ModelTable.name == model_name,
1386+
ModelTable.inst_id == str_to_uuid(inst_id),
1387+
)
1388+
)
1389+
)
1390+
.all()
1391+
)
1392+
1393+
if not inst_result or not query_result:
1394+
raise HTTPException(
1395+
status_code=status.HTTP_404_NOT_FOUND,
1396+
detail="Institution or model does not exist.",
1397+
)
1398+
1399+
try:
1400+
triggered_timestamp = datetime.now()
1401+
latest_model_version = databricks_control.fetch_model_version(
1402+
catalog_name=str(env_vars["CATALOG_NAME"]),
1403+
inst_name=inst_result[0][0].name,
1404+
model_name=model_name,
1405+
)
1406+
job = JobTable(
1407+
id=job_run_id,
1408+
triggered_at=triggered_timestamp,
1409+
created_by=str_to_uuid(current_user.user_id),
1410+
batch_name="No batch name (manual custom school job)",
1411+
model_id=query_result[0][0].id,
1412+
output_valid=False,
1413+
model_version=latest_model_version.version,
1414+
model_run_id=latest_model_version.run_id,
1415+
)
1416+
local_session.get().add(job)
1417+
1418+
return {
1419+
"inst_id": inst_id,
1420+
"m_name": model_name,
1421+
"run_id": job_run_id,
1422+
"created_by": current_user.user_id,
1423+
"triggered_at": triggered_timestamp,
1424+
"model_version": latest_model_version.version,
1425+
"model_run_id": latest_model_version.run_id,
1426+
}
1427+
except ValueError as ve:
1428+
# Return a 400 error with the specific message from ValueError
1429+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))

0 commit comments

Comments
 (0)