diff --git a/src/apiserver/api_models.py b/src/apiserver/api_models.py index bbbc85de..8cb98334 100644 --- a/src/apiserver/api_models.py +++ b/src/apiserver/api_models.py @@ -1,5 +1,4 @@ from __future__ import annotations -import uuid import pydantic import typing import db_models @@ -11,17 +10,14 @@ class ApiVersion(pydantic.BaseModel): class Run(pydantic.BaseModel): - id: uuid.UUID - """ - Run identifier, unique in the system. - """ + id: str toolchain_name: str problem_name: str - user_id: uuid.UUID + user_id: str contest_name: str status: typing.Mapping[str, str] = pydantic.Field(default_factory=dict) @staticmethod def from_db(doc: db_models.RunMainProj) -> Run: - return Run(id=doc['id'], toolchain_name=doc['toolchain_name'], - user_id=doc['user_id'], contest_name=doc['contest_name'], problem_name=doc['problem_name'], status=doc['status']) + return Run(id=str(doc['_id']), toolchain_name=doc['toolchain_name'], + user_id=str(doc['user_id']), contest_name=doc['contest_name'], problem_name=doc['problem_name'], status=doc['status']) diff --git a/src/apiserver/db_models.py b/src/apiserver/db_models.py index c07e3aa0..95924344 100644 --- a/src/apiserver/db_models.py +++ b/src/apiserver/db_models.py @@ -1,10 +1,28 @@ -import uuid +from bson import ObjectId from enum import Enum -import time import typing from pydantic import BaseModel, Field +class Id(ObjectId): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + if not isinstance(v, ObjectId): + raise TypeError('ObjectId required') + return str(v) + + @classmethod + def __modify_schema__(cls, schema): + schema.update({ + 'Title': 'Object ID', + 'type': 'string' + }) + + class RunPhase(Enum): """ # QUEUED @@ -24,10 +42,9 @@ class RunPhase(Enum): class RunMainProj(BaseModel): - id: uuid.UUID toolchain_name: str problem_name: str - user_id: uuid.UUID + user_id: Id contest_name: str phase: str # RunPhase status: typing.Mapping[str, str] = Field(default_factory=dict) @@ -37,7 +54,7 @@ class RunMainProj(BaseModel): """ -RunMainProj.FIELDS = ['id', 'toolchain_name', +RunMainProj.FIELDS = ['toolchain_name', 'problem_name', 'user_id', 'contest_name', 'status'] diff --git a/src/apiserver/routes.py b/src/apiserver/routes.py index 0f2e3c6b..556eb142 100644 --- a/src/apiserver/routes.py +++ b/src/apiserver/routes.py @@ -1,10 +1,10 @@ import fastapi import db_models import api_models -import uuid import typing import base64 import pymongo +from bson import ObjectId import pydantic @@ -97,15 +97,14 @@ def route_submit(params: RunSubmitSimpleParams, db: pymongo.database.Database = fields of request body; `id` will be real id of this run. """ - run_uuid = uuid.uuid4() - user_id = uuid.UUID('12345678123456781234567812345678') - doc_main = db_models.RunMainProj(id=run_uuid, toolchain_name=params.toolchain, + user_id = ObjectId('507f1f77bcf86cd799439011') + doc_main = db_models.RunMainProj(toolchain_name=params.toolchain, problem_name=params.problem, user_id=user_id, contest_name=params.contest, phase=str(db_models.RunPhase.QUEUED)) doc_source = db_models.RunSourceProj( source=base64.b64decode(params.code)) doc = {**dict(doc_main), **dict(doc_source)} - db.runs.insert_one(doc) - return api_models.Run(id=run_uuid, toolchain_name=params.toolchain, problem_name=params.problem, user_id=user_id, contest_name=params.contest) + result = db.runs.insert_one(doc) + return api_models.Run(id=str(result.inserted_id), toolchain_name=params.toolchain, problem_name=params.problem, user_id=str(user_id), contest_name=params.contest) @app.get('/runs', response_model=typing.List[api_models.Run], operation_id='listRuns') @@ -122,13 +121,12 @@ def route_list_runs(db: pymongo.database.Database = fastapi.Depends(db_connect)) return runs @app.get('/runs/{run_id}', response_model=api_models.Run, operation_id='getRun') - def route_get_run(run_id: uuid.UUID, db: pymongo.database.Database = fastapi.Depends(db_connect)): + def route_get_run(run_id: str, db: pymongo.database.Database = fastapi.Depends(db_connect)): """ Loads run by id """ - run = db.runs.find_one(projection=db_models.RunMainProj.FIELDS, filter={ - 'id': run_id + '_id': ObjectId(run_id) }) if run is None: raise fastapi.HTTPException(404, detail='RunNotFound') @@ -139,13 +137,13 @@ def route_get_run(run_id: uuid.UUID, db: pymongo.database.Database = fastapi.Dep 'description': "Run source is not available" } }) - def route_get_run_source(run_id: uuid.UUID, db: pymongo.database.Database = fastapi.Depends(db_connect)): + def route_get_run_source(run_id: str, db: pymongo.database.Database = fastapi.Depends(db_connect)): """ Returns run source as base64-encoded JSON string """ doc = db.runs.find_one(projection=['source'], filter={ - 'id': run_id + '_id': ObjectId(run_id) }) if doc is None: raise fastapi.HTTPException(404, detail='RunNotFound') @@ -154,7 +152,7 @@ def route_get_run_source(run_id: uuid.UUID, db: pymongo.database.Database = fast return base64.b64encode(doc['source']) @app.patch('/runs/{run_id}', response_model=api_models.Run, operation_id='patchRun') - def route_run_patch(run_id: uuid.UUID, patch: RunPatch, db: pymongo.database.Database = fastapi.Depends(db_connect)): + def route_run_patch(run_id: str, patch: RunPatch, db: pymongo.database.Database = fastapi.Depends(db_connect)): """ Modifies existing run @@ -177,7 +175,7 @@ def route_run_patch(run_id: uuid.UUID, patch: RunPatch, db: pymongo.database.Dat "RunPatch.status[*] must have length exactly 2") p['$set'][f"status.{status_to_add[0]}"] = status_to_add[1] updated_run = db.runs.find_one_and_update( - {'id': run_id}, p, projection=db_models.RunMainProj.FIELDS, return_document=pymongo.ReturnDocument.AFTER) + {'_id': ObjectId(run_id)}, p, projection=db_models.RunMainProj.FIELDS, return_document=pymongo.ReturnDocument.AFTER) if updated_run is None: raise fastapi.HTTPException(404, 'RunNotFound') return updated_run