Skip to content

Commit 065fd60

Browse files
committed
add job table
1 parent a0ed749 commit 065fd60

File tree

4 files changed

+90
-34
lines changed

4 files changed

+90
-34
lines changed

src/webapp/database.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Text,
1818
JSON,
1919
Integer,
20+
BigInteger,
2021
)
2122
from typing import Set, List
2223
from sqlalchemy.orm import sessionmaker, Session, relationship, mapped_column, Mapped
@@ -331,12 +332,11 @@ class ModelTable(Base):
331332
)
332333
inst: Mapped["InstTable"] = relationship(back_populates="models")
333334

335+
jobs: Mapped[Set["JobTable"]] = relationship(back_populates="model")
336+
334337
name = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False)
335338
# What configuration of schemas are allowed (list of maps e.g. [PDP Course : 1 + PDP Cohort : 1, X_schema :1 + Y_schema: 2])
336339
schema_configs = Column(MutableList.as_mutable(JSON), nullable=True)
337-
# A list of all the runs executed using this model. These ids will correspond to Databricks ids so that we can retrieve things like
338-
# status and correlate output using Databricks.
339-
run_ids = Column(MutableList.as_mutable(JSON), nullable=True)
340340
created_by = Column(Uuid(as_uuid=True), nullable=True)
341341
# If null, the following is non-deleted.
342342
deleted: Mapped[bool] = mapped_column(nullable=True)
@@ -355,6 +355,28 @@ class ModelTable(Base):
355355
)
356356

357357

358+
class JobTable(Base):
359+
__tablename__ = "job"
360+
id = Column(BigInteger, primary_key=True)
361+
362+
# Set the parent foreign key to link to the institution table.
363+
model_id = Column(
364+
Uuid(as_uuid=True),
365+
ForeignKey("model.id", ondelete="CASCADE"),
366+
nullable=False,
367+
)
368+
model: Mapped["ModelTable"] = relationship(back_populates="jobs")
369+
370+
created_by = Column(Uuid(as_uuid=True), nullable=False)
371+
# The time the deletion request was set.
372+
triggered_at = Column(DateTime(timezone=True), nullable=False)
373+
batch_name = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False)
374+
# The following will be empty if not completed or if job errored out. Getting additional details will require a call to the Databricks table.
375+
output_filename = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=True)
376+
err_msg = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=True)
377+
completed: Mapped[bool] = mapped_column(nullable=True)
378+
379+
358380
"""
359381
def get_one_record(sess_context_var: ContextVar, sess: Session, select_query: ) -> Any:
360382
local_session.set(sql_session)

src/webapp/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ async def set_datakinders(
203203
status_code=status.HTTP_401_UNAUTHORIZED,
204204
detail="Only Datakinders can set other Datakinders",
205205
)
206+
# TODO xxx check the user doesn't have an inst first
206207
local_session.set(sql_session)
207208
local_session.get().execute(
208209
update(AccountTable)

src/webapp/routers/models.py

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
"""
33

44
import uuid
5-
65
from typing import Annotated, Any, Tuple, Dict
76
from fastapi import APIRouter, Depends, HTTPException, status
87
from pydantic import BaseModel
@@ -30,6 +29,7 @@
3029
FileTable,
3130
InstTable,
3231
ModelTable,
32+
JobTable,
3333
)
3434

3535
from ..gcsutil import StorageControl
@@ -59,6 +59,7 @@ class ModelCreationRequest(BaseModel):
5959
class ModelInfo(BaseModel):
6060
"""The model object that's returned."""
6161

62+
# The model id is unique for every instance of the model (e.g. model name + version id pair)
6263
m_id: str
6364
name: str
6465
inst_id: str
@@ -76,15 +77,23 @@ class RunInfo(BaseModel):
7677
run_id: str
7778
vers_id: int = 0
7879
inst_id: str
79-
m_id: str
80+
m_name: str
8081
# user id of the person who executed this run.
8182
created_by: str | None = None
83+
# Time the run info was triggered if it was triggered in the webapp
84+
triggered_at: datetime | None = None
85+
# Batch used for the run
86+
batch_name: str | None = None
87+
# output file name
88+
output_filename: str | None = None
89+
completed: bool | None = None
90+
err_msg: str | None = None
8291

8392

8493
class InferenceRunRequest(BaseModel):
8594
"""Parameters for an inference run."""
8695

87-
batch_id: str
96+
batch_name: str
8897

8998

9099
# Model related operations. Or model specific data.
@@ -357,18 +366,21 @@ def read_inst_model_outputs(
357366
detail="Multiple models of the same version found, this should not have happened.",
358367
)
359368
res = []
360-
if query_result[0][0].run_ids:
361-
for elem in query_result[0][0].run_ids:
362-
res.append(
363-
{
364-
# xxxx
365-
"run_id": "placeholder",
366-
"inst_id": uuid_to_str(query_result[0][0].inst_id),
367-
"m_id": uuid_to_str(query_result[0][0].id),
368-
"created_by": "placeholder",
369-
"vers_id": vers_id,
370-
}
371-
)
369+
for elem in query_result[0][0].jobs or []:
370+
# if not elem.output_filename:
371+
# TODO make a query to databricks to retrieve status.
372+
res.append(
373+
{
374+
"vers_id": vers_id,
375+
"inst_id": uuid_to_str(query_result[0][0].inst_id),
376+
"m_name": query_result[0][0].name,
377+
"run_id": elem.id,
378+
"created_by": uuid_to_str(elem.created_by),
379+
"triggered_at": elem.triggered_at,
380+
"batch_name": elem.batch_name,
381+
"output_filename": elem.output_filename,
382+
}
383+
)
372384
return res
373385

374386

@@ -417,21 +429,27 @@ def read_inst_model_output(
417429
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
418430
detail="Multiple models of the same version found, this should not have happened.",
419431
)
420-
if query_result[0][0].run_ids:
421-
# TODO xxx: check the run id is present, then make a query to Databricks
422-
return {
423-
"run_id": "placeholder",
424-
"inst_id": uuid_to_str(query_result[0][0].inst_id),
425-
"m_id": uuid_to_str(query_result[0][0].id),
426-
"created_by": "placeholder",
427-
"vers_id": vers_id,
428-
}
432+
433+
for elem in query_result[0][0].jobs or []:
434+
if elem.id == run_id:
435+
# TODO xxx: if the output_filename is empty make a query to Databricks
436+
return {
437+
"vers_id": vers_id,
438+
"inst_id": uuid_to_str(query_result[0][0].inst_id),
439+
"m_name": query_result[0][0].name,
440+
"run_id": elem.id,
441+
"created_by": uuid_to_str(elem.created_by),
442+
"triggered_at": elem.triggered_at,
443+
"batch_name": elem.batch_name,
444+
"output_filename": elem.output_filename,
445+
}
429446
raise HTTPException(
430447
status_code=status.HTTP_404_NOT_FOUND,
431448
detail="Run not found.",
432449
)
433450

434451

452+
# TODO: xxx update the run info returned items.
435453
@router.post(
436454
"/{inst_id}/models/{model_name}/vers/{vers_id}/run-inference",
437455
response_model=RunInfo,
@@ -476,13 +494,28 @@ def trigger_inference_run(
476494
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
477495
detail="Multiple models of the same version found, this should not have happened.",
478496
)
479-
# TODO issue Databricks call
480-
return { # xxxx
481-
"run_id": "placeholder",
482-
"inst_id": uuid_to_str(query_result[0][0].inst_id),
483-
"m_id": uuid_to_str(query_result[0][0].id),
484-
"created_by": "placeholder",
497+
# TODO issue Databricks call then use the result to populate the below
498+
499+
triggered_timestamp = datetime.now()
500+
# Add an entry to the jobs table with the job id
501+
"""
502+
job = JobTable(
503+
id=12345,
504+
triggered_at=triggered_timestamp,
505+
created_by=str_to_uuid(current_user.user_id),
506+
batch_name=req.batch_name,
507+
model_id=query_result[0][0].id,
508+
)
509+
local_session.get().add(job)
510+
"""
511+
return {
485512
"vers_id": vers_id,
513+
"inst_id": uuid_to_str(query_result[0][0].inst_id),
514+
"m_name": query_result[0][0].name,
515+
"run_id": "placeholder",
516+
"created_by": current_user.user_id,
517+
"triggered_at": triggered_timestamp,
518+
"batch_name": req.batch_name,
486519
}
487520

488521

src/webapp/routers/models_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def test_trigger_inference_run(client: TestClient):
286286
+ uuid_to_str(USER_VALID_INST_UUID)
287287
+ "/models/sample_model_for_school_1/vers/0/run-inference",
288288
json={
289-
"batch_id": "abc",
289+
"batch_name": "abc",
290290
},
291291
)
292292

0 commit comments

Comments
 (0)