Skip to content

Commit 5d94c50

Browse files
authored
feat: Allow passing a tracking ID for API requests with side-effects (#2)
Extend API to accept a tracking ID as an optional query parameter, allowing upstream systems to track training requests. Validate received IDs to ensure they're alphanumeric strings of length 1-256, following MLflow's internal run ID validation model. Extend serving tests to check that the ID (if provided) is included in the API's response. Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 6d89586 commit 5d94c50

File tree

9 files changed

+223
-18
lines changed

9 files changed

+223
-18
lines changed

app/api/dependencies.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
import logging
2+
import re
3+
from typing import Union
4+
from typing_extensions import Annotated
5+
6+
from fastapi import HTTPException, Query
7+
from starlette.status import HTTP_400_BAD_REQUEST
28

39
from typing import Optional
410
from config import Settings
511
from registry import model_service_registry
612
from model_services.base import AbstractModelService
713
from management.model_manager import ModelManager
814

15+
# Tracking IDs start with an ASCII letter or digit, followed by up to 255
# ASCII letters, digits, underscores or hyphens (1 to 256 characters in total).
# `\Z` is used instead of `$` because `$` also matches just before a trailing
# newline, which would let "abc\n" through. `re.ASCII` keeps `\w` from matching
# non-ASCII word characters; validated IDs are interpolated into HTTP header
# values, where neither line breaks nor arbitrary Unicode are safe.
TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}\Z", re.ASCII)
16+
917
logger = logging.getLogger("cms")
1018

1119

@@ -45,3 +53,14 @@ def __init__(self, model_service: AbstractModelService) -> None:
4553

4654
def __call__(self) -> ModelManager:
4755
return self._model_manager
56+
57+
58+
def validate_tracking_id(
    tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the requested task")] = None,
) -> Union[str, None]:
    """Validate the optional ``tracking_id`` query parameter.

    Intended to be used as a FastAPI dependency (``Depends(validate_tracking_id)``).

    Args:
        tracking_id: The tracking ID supplied by the client, or ``None`` when
            the query parameter is absent.

    Returns:
        The tracking ID unchanged when it is valid, or ``None`` when no
        tracking ID was supplied.

    Raises:
        HTTPException: With status 400 when the tracking ID is supplied but
            does not fully match ``TRACKING_ID_REGEX`` or contains non-ASCII
            characters.
    """
    if tracking_id is None:
        return None
    # `fullmatch` requires the whole string to match, so a trailing newline is
    # rejected even though `$` alone would tolerate it. The ASCII check rejects
    # Unicode word characters that `\w` would otherwise accept. Callers place
    # the ID inside response header values, so both cases must be refused here.
    if not tracking_id.isascii() or TRACKING_ID_REGEX.fullmatch(tracking_id) is None:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail=f"Invalid tracking ID '{tracking_id}', must be an alphanumeric string of length 1 to 256",
        )
    return tracking_id

app/api/routers/evaluation.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
import uuid
55
import tempfile
66

7-
from typing import List
7+
from typing import List, Union
88
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE
99
from typing_extensions import Annotated
1010
from fastapi import APIRouter, Query, Depends, UploadFile, Request, File
1111
from fastapi.responses import StreamingResponse, JSONResponse
1212

1313
import api.globals as cms_globals
14+
from api.dependencies import validate_tracking_id
1415
from domain import Tags, Scope
1516
from model_services.base import AbstractModelService
1617
from processors.metrics_collector import (
@@ -34,6 +35,7 @@
3435
description="Evaluate the model being served with a trainer export")
3536
async def get_evaluation_with_trainer_export(request: Request,
3637
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
38+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3739
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
3840
files = []
3941
file_names = []
@@ -54,7 +56,7 @@ async def get_evaluation_with_trainer_export(request: Request,
5456
json.dump(concatenated, data_file)
5557
data_file.flush()
5658
data_file.seek(0)
57-
evaluation_id = str(uuid.uuid4())
59+
evaluation_id = tracking_id or str(uuid.uuid4())
5860
evaluation_accepted = model_service.train_supervised(data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names))
5961
if evaluation_accepted:
6062
return JSONResponse(content={"message": "Your evaluation started successfully.", "evaluation_id": evaluation_id}, status_code=HTTP_202_ACCEPTED)
@@ -69,6 +71,7 @@ async def get_evaluation_with_trainer_export(request: Request,
6971
description="Sanity check the model being served with a trainer export")
7072
def get_sanity_check_with_trainer_export(request: Request,
7173
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
74+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
7275
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
7376
files = []
7477
file_names = []
@@ -88,8 +91,9 @@ def get_sanity_check_with_trainer_export(request: Request,
8891
metrics = sanity_check_model_with_trainer_export(concatenated, model_service, return_df=True, include_anchors=False)
8992
stream = io.StringIO()
9093
metrics.to_csv(stream, index=False)
94+
tracking_id = tracking_id or str(uuid.uuid4())
9195
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
92-
response.headers["Content-Disposition"] = f'attachment ; filename="sanity_check_{str(uuid.uuid4())}.csv"'
96+
response.headers["Content-Disposition"] = f'attachment ; filename="sanity_check_{tracking_id}.csv"'
9397
return response
9498

9599

@@ -102,7 +106,8 @@ def get_inter_annotator_agreement_scores(request: Request,
102106
trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")],
103107
annotator_a_project_id: Annotated[int, Query(description="The project ID from one annotator")],
104108
annotator_b_project_id: Annotated[int, Query(description="The project ID from another annotator")],
105-
scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")]) -> StreamingResponse:
109+
scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")],
110+
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse:
106111
files = []
107112
for te in trainer_export:
108113
temp_te = tempfile.NamedTemporaryFile()
@@ -126,8 +131,9 @@ def get_inter_annotator_agreement_scores(request: Request,
126131
raise AnnotationException(f'Unknown scope: "{scope}"')
127132
stream = io.StringIO()
128133
iaa_scores.to_csv(stream, index=False)
134+
tracking_id = tracking_id or str(uuid.uuid4())
129135
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
130-
response.headers["Content-Disposition"] = f'attachment ; filename="iaa_{str(uuid.uuid4())}.csv"'
136+
response.headers["Content-Disposition"] = f'attachment ; filename="iaa_{tracking_id}.csv"'
131137
return response
132138

133139

@@ -137,7 +143,8 @@ def get_inter_annotator_agreement_scores(request: Request,
137143
dependencies=[Depends(cms_globals.props.current_active_user)],
138144
description="Concatenate multiple trainer export files into a single file for download")
139145
def get_concatenated_trainer_exports(request: Request,
140-
trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")]) -> JSONResponse:
146+
trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")],
147+
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> JSONResponse:
141148
files = []
142149
for te in trainer_export:
143150
temp_te = tempfile.NamedTemporaryFile()
@@ -148,8 +155,9 @@ def get_concatenated_trainer_exports(request: Request,
148155
concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False)
149156
for file in files:
150157
file.close()
158+
tracking_id = tracking_id or str(uuid.uuid4())
151159
response = JSONResponse(concatenated, media_type="application/json; charset=utf-8")
152-
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{str(uuid.uuid4())}.json"'
160+
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{tracking_id}.json"'
153161
return response
154162

155163

@@ -159,7 +167,8 @@ def get_concatenated_trainer_exports(request: Request,
159167
dependencies=[Depends(cms_globals.props.current_active_user)],
160168
description="Get annotation stats of trainer export files")
161169
def get_annotation_stats(request: Request,
162-
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")]) -> StreamingResponse:
170+
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
171+
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse:
163172
files = []
164173
file_names = []
165174
for te in trainer_export:
@@ -177,6 +186,7 @@ def get_annotation_stats(request: Request,
177186
stats = get_stats_from_trainer_export(concatenated, return_df=True)
178187
stream = io.StringIO()
179188
stats.to_csv(stream, index=False)
189+
tracking_id = tracking_id or str(uuid.uuid4())
180190
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
181-
response.headers["Content-Disposition"] = f'attachment ; filename="stats_{str(uuid.uuid4())}.csv"'
191+
response.headers["Content-Disposition"] = f'attachment ; filename="stats_{tracking_id}.csv"'
182192
return response

app/api/routers/invocation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from domain import TextWithAnnotations, TextWithPublicKey, TextStreamItem, ModelCard, Tags
2121
from model_services.base import AbstractModelService
2222
from utils import get_settings
23+
from api.dependencies import validate_tracking_id
2324
from api.utils import get_rate_limiter, encrypt
2425
from management.prometheus_metrics import (
2526
cms_doc_annotations,
@@ -132,6 +133,7 @@ def get_entities_from_multiple_texts(request: Request,
132133
description="Upload a file containing a list of plain text and extract the NER entities in JSON")
133134
def extract_entities_from_multi_text_file(request: Request,
134135
multi_text_file: Annotated[UploadFile, File(description="A file containing a list of plain texts, in the format of [\"text_1\", \"text_2\", ..., \"text_n\"]")],
136+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
135137
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
136138
with tempfile.NamedTemporaryFile() as data_file:
137139
for line in multi_text_file.file:
@@ -160,8 +162,9 @@ def extract_entities_from_multi_text_file(request: Request,
160162
output = json.dumps(body)
161163
logger.debug(output)
162164
json_file = BytesIO(output.encode())
165+
tracking_id = tracking_id or str(uuid.uuid4())
163166
response = StreamingResponse(json_file, media_type="application/json")
164-
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{str(uuid.uuid4())}.json"'
167+
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{tracking_id}.json"'
165168
return response
166169

167170

app/api/routers/metacat_training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE
1111

1212
import api.globals as cms_globals
13+
from api.dependencies import validate_tracking_id
1314
from domain import Tags
1415
from model_services.base import AbstractModelService
1516
from processors.metrics_collector import concat_trainer_exports
@@ -29,6 +30,7 @@ async def train_metacat(request: Request,
2930
epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1,
3031
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
3132
description: Annotated[Union[str, None], Query(description="The description on the training or change logs")] = None,
33+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3234
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
3335
files = []
3436
file_names = []
@@ -49,7 +51,7 @@ async def train_metacat(request: Request,
4951
json.dump(concatenated, data_file)
5052
data_file.flush()
5153
data_file.seek(0)
52-
training_id = str(uuid.uuid4())
54+
training_id = tracking_id or str(uuid.uuid4())
5355
try:
5456
training_accepted = model_service.train_metacat(data_file,
5557
epochs,

app/api/routers/preview.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from starlette.status import HTTP_404_NOT_FOUND
1212

1313
import api.globals as cms_globals
14+
from api.dependencies import validate_tracking_id
1415
from domain import Doc, Tags
1516
from model_services.base import AbstractModelService
1617
from processors.metrics_collector import concat_trainer_exports
@@ -27,14 +28,16 @@
2728
description="Extract the NER entities in HTML for preview")
2829
async def get_rendered_entities_from_text(request: Request,
2930
text: Annotated[str, Body(description="The text to be sent to the model for NER", media_type="text/plain")],
31+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3032
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
3133
annotations = model_service.annotate(text)
3234
entities = annotations_to_entities(annotations, model_service.model_name)
3335
logger.debug("Entities extracted for previewing %s", entities)
3436
ent_input = Doc(text=text, ents=entities)
3537
data = displacy.render(ent_input.dict(), style="ent", manual=True)
38+
tracking_id = tracking_id or str(uuid.uuid4())
3639
response = StreamingResponse(BytesIO(data.encode()), media_type="application/octet-stream")
37-
response.headers["Content-Disposition"] = f'attachment ; filename="preview_{str(uuid.uuid4())}.html"'
40+
response.headers["Content-Disposition"] = f'attachment ; filename="preview_{tracking_id}.html"'
3841
return response
3942

4043

@@ -47,7 +50,8 @@ def get_rendered_entities_from_trainer_export(request: Request,
4750
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")] = [],
4851
trainer_export_str: Annotated[str, Form(description="The trainer export raw JSON string")] = "{\"projects\": []}",
4952
project_id: Annotated[Union[int, None], Query(description="The target project ID, and if not provided, all projects will be included")] = None,
50-
document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None) -> Response:
53+
document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None,
54+
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> Response:
5155
data: Dict = {"projects": []}
5256
if trainer_export is not None:
5357
files = []
@@ -88,8 +92,9 @@ def get_rendered_entities_from_trainer_export(request: Request,
8892
doc = Doc(text=document["text"], ents=entities, title=f"P{project['id']}/D{document['id']}")
8993
htmls.append(displacy.render(doc.dict(), style="ent", manual=True))
9094
if htmls:
95+
tracking_id = tracking_id or str(uuid.uuid4())
9196
response = StreamingResponse(BytesIO("<br/>".join(htmls).encode()), media_type="application/octet-stream")
92-
response.headers["Content-Disposition"] = f'attachment ; filename="preview_{str(uuid.uuid4())}.html"'
97+
response.headers["Content-Disposition"] = f'attachment ; filename="preview_{tracking_id}.html"'
9398
else:
9499
logger.debug("Cannot find any matching documents to preview")
95100
return JSONResponse(content={"message": "Cannot find any matching documents to preview"}, status_code=HTTP_404_NOT_FOUND)

app/api/routers/supervised_training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE
1111

1212
import api.globals as cms_globals
13+
from api.dependencies import validate_tracking_id
1314
from domain import Tags
1415
from model_services.base import AbstractModelService
1516
from processors.metrics_collector import concat_trainer_exports
@@ -32,6 +33,7 @@ async def train_supervised(request: Request,
3233
test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage. (For a 'huggingface-ner' model, a negative value can be used to apply the train-validation-test split if implicitly defined in trainer export: 'projects[0]' is used for training, 'projects[1]' for validation, and 'projects[2]' for testing)")] = 0.2,
3334
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
3435
description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None,
36+
tracking_id: Union[str, None] = Depends(validate_tracking_id),
3537
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
3638
files = []
3739
file_names = []
@@ -51,7 +53,7 @@ async def train_supervised(request: Request,
5153
json.dump(concatenated, data_file)
5254
data_file.flush()
5355
data_file.seek(0)
54-
training_id = str(uuid.uuid4())
56+
training_id = tracking_id or str(uuid.uuid4())
5557
try:
5658
training_accepted = model_service.train_supervised(data_file,
5759
epochs,

0 commit comments

Comments
 (0)