|
2 | 2 |
|
3 | 3 | import uuid |
4 | 4 | from datetime import datetime, date |
5 | | -from typing import Annotated, Any, Dict, List, cast, IO, Optional, Tuple |
| 5 | +from typing import Annotated, Any, Dict, List, Optional, Tuple |
6 | 6 | from pydantic import BaseModel, Field |
7 | | -from fastapi import APIRouter, Depends, HTTPException, status, Response, Query |
| 7 | +from fastapi import APIRouter, Depends, HTTPException, status, Response |
8 | 8 | from sqlalchemy import and_, or_ |
9 | 9 | from sqlalchemy.orm import Session |
10 | 10 | from sqlalchemy.future import select |
|
33 | 33 | BatchTable, |
34 | 34 | FileTable, |
35 | 35 | InstTable, |
| 36 | + JobTable, |
| 37 | + ModelTable, |
36 | 38 | SchemaRegistryTable, |
37 | 39 | DocType, |
38 | 40 | ) |
|
41 | 43 | from ..gcsdbutils import update_db_from_bucket |
42 | 44 |
|
43 | 45 | from ..gcsutil import StorageControl |
| 46 | +from ..config import env_vars |
44 | 47 |
|
45 | 48 | # Set the logging |
46 | 49 | logging.basicConfig(format="%(asctime)s [%(levelname)s]: %(message)s") |
@@ -1345,3 +1348,82 @@ def get_upload_url( |
1345 | 1348 | except ValueError as ve: |
1346 | 1349 | # Return a 400 error with the specific message from ValueError |
1347 | 1350 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
| 1351 | + |
| 1352 | + |
| 1353 | +@router.get("/{inst_id}/add-custom-school-job/{job_run_id}") |
| 1354 | +def add_custom_school_job( |
| 1355 | + inst_id: str, |
| 1356 | + job_run_id: str, |
| 1357 | + model_name: str, |
| 1358 | + sql_session: Annotated[Session, Depends(get_session)], |
| 1359 | + current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
| 1360 | + databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)], |
| 1361 | +) -> Any: |
| 1362 | + """Fill in a JobTable .""" |
| 1363 | + has_access_to_inst_or_err(inst_id, current_user) |
| 1364 | + has_full_data_access_or_err(current_user, "this model") |
| 1365 | + local_session.set(sql_session) |
| 1366 | + |
| 1367 | + model_name = decode_url_piece(model_name) |
| 1368 | + inst_result = ( |
| 1369 | + local_session.get() |
| 1370 | + .execute( |
| 1371 | + select(InstTable).where( |
| 1372 | + and_( |
| 1373 | + InstTable.id == str_to_uuid(inst_id), |
| 1374 | + ) |
| 1375 | + ) |
| 1376 | + ) |
| 1377 | + .all() |
| 1378 | + ) |
| 1379 | + |
| 1380 | + query_result = ( |
| 1381 | + local_session.get() |
| 1382 | + .execute( |
| 1383 | + select(ModelTable).where( |
| 1384 | + and_( |
| 1385 | + ModelTable.name == model_name, |
| 1386 | + ModelTable.inst_id == str_to_uuid(inst_id), |
| 1387 | + ) |
| 1388 | + ) |
| 1389 | + ) |
| 1390 | + .all() |
| 1391 | + ) |
| 1392 | + |
| 1393 | + if not inst_result or not query_result: |
| 1394 | + raise HTTPException( |
| 1395 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1396 | + detail="Institution or model does not exist.", |
| 1397 | + ) |
| 1398 | + |
| 1399 | + try: |
| 1400 | + triggered_timestamp = datetime.now() |
| 1401 | + latest_model_version = databricks_control.fetch_model_version( |
| 1402 | + catalog_name=str(env_vars["CATALOG_NAME"]), |
| 1403 | + inst_name=inst_result[0][0].name, |
| 1404 | + model_name=model_name, |
| 1405 | + ) |
| 1406 | + job = JobTable( |
| 1407 | + id=job_run_id, |
| 1408 | + triggered_at=triggered_timestamp, |
| 1409 | + created_by=str_to_uuid(current_user.user_id), |
| 1410 | + batch_name="No batch name (manual custom school job)", |
| 1411 | + model_id=query_result[0][0].id, |
| 1412 | + output_valid=False, |
| 1413 | + model_version=latest_model_version.version, |
| 1414 | + model_run_id=latest_model_version.run_id, |
| 1415 | + ) |
| 1416 | + local_session.get().add(job) |
| 1417 | + |
| 1418 | + return { |
| 1419 | + "inst_id": inst_id, |
| 1420 | + "m_name": model_name, |
| 1421 | + "run_id": job_run_id, |
| 1422 | + "created_by": current_user.user_id, |
| 1423 | + "triggered_at": triggered_timestamp, |
| 1424 | + "model_version": latest_model_version.version, |
| 1425 | + "model_run_id": latest_model_version.run_id, |
| 1426 | + } |
| 1427 | + except ValueError as ve: |
| 1428 | + # Return a 400 error with the specific message from ValueError |
| 1429 | + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
0 commit comments