Skip to content

Commit d9fde4f

Browse files
committed
api: Add tracking ID validation
Validate the tracking ID in the API endpoints that require it, ensuring it is a string of length 1 to 256 that starts with an alphanumeric character and otherwise contains only alphanumerics, underscores, and hyphens. The implementation and tests are based on MLflow's internal run ID validation: https://github.com/mlflow/mlflow/blob/92a1664ddbd7ef59f8db45e988e41437d179c3b1/mlflow/utils/validation.py#L374-L377 Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 59e7bb4 commit d9fde4f

File tree

8 files changed

+66
-13
lines changed

8 files changed

+66
-13
lines changed

app/api/dependencies.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
import logging
2+
import re
3+
from typing import Union
4+
from typing_extensions import Annotated
5+
6+
from fastapi import HTTPException, Query
7+
from starlette.status import HTTP_400_BAD_REQUEST
28

39
from typing import Optional
410
from config import Settings
511
from registry import model_service_registry
612
from model_services.base import AbstractModelService
713
from management.model_manager import ModelManager
814

15+
TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}$")
16+
917
logger = logging.getLogger("cms")
1018

1119

@@ -45,3 +53,14 @@ def __init__(self, model_service: AbstractModelService) -> None:
4553

4654
def __call__(self) -> ModelManager:
4755
return self._model_manager
56+
57+
58+
def validate_tracking_id(
59+
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the requested task")] = None,
60+
) -> Union[str, None]:
61+
if tracking_id is not None and TRACKING_ID_REGEX.match(tracking_id) is None:
62+
raise HTTPException(
63+
status_code=HTTP_400_BAD_REQUEST,
64+
detail=f"Invalid tracking ID '{tracking_id}', must be an alphanumeric string of length 1 to 256",
65+
)
66+
return tracking_id

app/api/routers/evaluation.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from fastapi.responses import StreamingResponse, JSONResponse
1212

1313
import api.globals as cms_globals
14+
from api.dependencies import validate_tracking_id
1415
from domain import Tags, Scope
1516
from model_services.base import AbstractModelService
1617
from processors.metrics_collector import (
@@ -34,7 +35,7 @@
3435
description="Evaluate the model being served with a trainer export")
3536
async def get_evaluation_with_trainer_export(request: Request,
3637
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
37-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the evaluation task")] = None,
38+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3839
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
3940
files = []
4041
file_names = []
@@ -70,7 +71,7 @@ async def get_evaluation_with_trainer_export(request: Request,
7071
description="Sanity check the model being served with a trainer export")
7172
def get_sanity_check_with_trainer_export(request: Request,
7273
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
73-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the sanity check task")] = None,
74+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
7475
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
7576
files = []
7677
file_names = []
@@ -106,7 +107,7 @@ def get_inter_annotator_agreement_scores(request: Request,
106107
annotator_a_project_id: Annotated[int, Query(description="The project ID from one annotator")],
107108
annotator_b_project_id: Annotated[int, Query(description="The project ID from another annotator")],
108109
scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")],
109-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the IAA task")] = None) -> StreamingResponse:
110+
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse:
110111
files = []
111112
for te in trainer_export:
112113
temp_te = tempfile.NamedTemporaryFile()
@@ -143,7 +144,7 @@ def get_inter_annotator_agreement_scores(request: Request,
143144
description="Concatenate multiple trainer export files into a single file for download")
144145
def get_concatenated_trainer_exports(request: Request,
145146
trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")],
146-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the concatenation task")] = None) -> JSONResponse:
147+
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> JSONResponse:
147148
files = []
148149
for te in trainer_export:
149150
temp_te = tempfile.NamedTemporaryFile()
@@ -167,7 +168,7 @@ def get_concatenated_trainer_exports(request: Request,
167168
description="Get annotation stats of trainer export files")
168169
def get_annotation_stats(request: Request,
169170
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
170-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the annotation stats task")] = None) -> StreamingResponse:
171+
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse:
171172
files = []
172173
file_names = []
173174
for te in trainer_export:

app/api/routers/invocation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from domain import TextWithAnnotations, TextWithPublicKey, TextStreamItem, ModelCard, Tags
2121
from model_services.base import AbstractModelService
2222
from utils import get_settings
23+
from api.dependencies import validate_tracking_id
2324
from api.utils import get_rate_limiter, encrypt
2425
from management.prometheus_metrics import (
2526
cms_doc_annotations,
@@ -132,7 +133,7 @@ def get_entities_from_multiple_texts(request: Request,
132133
description="Upload a file containing a list of plain text and extract the NER entities in JSON")
133134
def extract_entities_from_multi_text_file(request: Request,
134135
multi_text_file: Annotated[UploadFile, File(description="A file containing a list of plain texts, in the format of [\"text_1\", \"text_2\", ..., \"text_n\"]")],
135-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the bulk processing task")] = None,
136+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
136137
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
137138
with tempfile.NamedTemporaryFile() as data_file:
138139
for line in multi_text_file.file:

app/api/routers/metacat_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE
1111

1212
import api.globals as cms_globals
13+
from api.dependencies import validate_tracking_id
1314
from domain import Tags
1415
from model_services.base import AbstractModelService
1516
from processors.metrics_collector import concat_trainer_exports
@@ -29,7 +30,7 @@ async def train_metacat(request: Request,
2930
epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1,
3031
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
3132
description: Annotated[Union[str, None], Query(description="The description on the training or change logs")] = None,
32-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None,
33+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3334
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
3435
files = []
3536
file_names = []

app/api/routers/preview.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from starlette.status import HTTP_404_NOT_FOUND
1212

1313
import api.globals as cms_globals
14+
from api.dependencies import validate_tracking_id
1415
from domain import Doc, Tags
1516
from model_services.base import AbstractModelService
1617
from processors.metrics_collector import concat_trainer_exports
@@ -27,7 +28,7 @@
2728
description="Extract the NER entities in HTML for preview")
2829
async def get_rendered_entities_from_text(request: Request,
2930
text: Annotated[str, Body(description="The text to be sent to the model for NER", media_type="text/plain")],
30-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the preview task")] = None,
31+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3132
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
3233
annotations = model_service.annotate(text)
3334
entities = annotations_to_entities(annotations, model_service.model_name)
@@ -50,7 +51,7 @@ def get_rendered_entities_from_trainer_export(request: Request,
5051
trainer_export_str: Annotated[str, Form(description="The trainer export raw JSON string")] = "{\"projects\": []}",
5152
project_id: Annotated[Union[int, None], Query(description="The target project ID, and if not provided, all projects will be included")] = None,
5253
document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None,
53-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the trainer export preview task")] = None) -> Response:
54+
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> Response:
5455
data: Dict = {"projects": []}
5556
if trainer_export is not None:
5657
files = []

app/api/routers/supervised_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE
1111

1212
import api.globals as cms_globals
13+
from api.dependencies import validate_tracking_id
1314
from domain import Tags
1415
from model_services.base import AbstractModelService
1516
from processors.metrics_collector import concat_trainer_exports
@@ -32,7 +33,7 @@ async def train_supervised(request: Request,
3233
test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage. (For a 'huggingface-ner' model, a negative value can be used to apply the train-validation-test split if implicitly defined in trainer export: 'projects[0]' is used for training, 'projects[1]' for validation, and 'projects[2]' for testing)")] = 0.2,
3334
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
3435
description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None,
35-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None,
36+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3637
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
3738
files = []
3839
file_names = []

app/api/routers/unsupervised_training.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from fastapi.responses import JSONResponse
1313
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE
1414
import api.globals as cms_globals
15+
from api.dependencies import validate_tracking_id
1516
from domain import Tags, ModelType
1617
from model_services.base import AbstractModelService
1718
from utils import get_settings
@@ -33,7 +34,7 @@ async def train_unsupervised(request: Request,
3334
test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage", ge=0.0)] = 0.2,
3435
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1000,
3536
description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None,
36-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None,
37+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3738
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
3839
"""
3940
Upload one or more plain text files and trigger the unsupervised training
@@ -97,7 +98,7 @@ async def train_unsupervised_with_hf_dataset(request: Request,
9798
test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage will only take effect if the dataset does not have predefined validation or test splits", ge=0.0)] = 0.2,
9899
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1000,
99100
description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None,
100-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None,
101+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
101102
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
102103
"""
103104
Trigger the unsupervised training with a dataset from Hugging Face Hub

tests/app/api/test_dependencies.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from api.dependencies import ModelServiceDep
1+
import pytest
2+
from fastapi import HTTPException
3+
4+
from api.dependencies import ModelServiceDep, validate_tracking_id
25
from config import Settings
36
from model_services.medcat_model import MedCATModel
47
from model_services.medcat_model_icd10 import MedCATModelIcd10
@@ -36,3 +39,28 @@ def test_transformer_deid_dep():
3639
def test_huggingface_ner_dep():
3740
model_service_dep = ModelServiceDep("huggingface_ner", Settings())
3841
assert isinstance(model_service_dep(), HuggingFaceNerModel)
42+
43+
44+
@pytest.mark.parametrize(
45+
"run_id",
46+
[
47+
"a" * 32,
48+
"A" * 32,
49+
"a" * 256,
50+
"f0" * 16,
51+
"abcdef0123456789" * 2,
52+
"abcdefghijklmnopqrstuvqxyz",
53+
"123e4567-e89b-12d3-a456-426614174000",
54+
"123e4567e89b12d3a45642661417400",
55+
],
56+
)
57+
def test_validate_tracking_id(run_id):
58+
assert validate_tracking_id(run_id) == run_id
59+
60+
61+
@pytest.mark.parametrize("run_id", ["a/bc" * 8, "", "a" * 400, "*" * 5])
62+
def test_validate_tracking_id_invalid(run_id):
63+
with pytest.raises(HTTPException) as exc_info:
64+
validate_tracking_id(run_id)
65+
assert exc_info.value.status_code == 400
66+
assert "Invalid tracking ID" in exc_info.value.detail

0 commit comments

Comments
 (0)