Skip to content

Commit ac0a8c1

Browse files
authored
train: Return MLflow tracking information (#4)
Extend the API to return MLflow tracking information as part of a training or evaluation response, including the experiment and run IDs, and update the tests accordingly. If training is already in progress, the API returns the experiment and run IDs of the current training run. This affects the following routes: * POST /train_supervised * POST /train_unsupervised * POST /train_unsupervised_with_hf_hub_dataset * POST /train_metacat * POST /evaluate Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 72bbe83 commit ac0a8c1

File tree

11 files changed

+122
-49
lines changed

11 files changed

+122
-49
lines changed

app/api/routers/evaluation.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,26 @@ async def get_evaluation_with_trainer_export(request: Request,
5757
data_file.flush()
5858
data_file.seek(0)
5959
evaluation_id = tracking_id or str(uuid.uuid4())
60-
evaluation_accepted = model_service.train_supervised(data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names))
60+
evaluation_accepted, experiment_id, run_id = model_service.train_supervised(
61+
data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names)
62+
)
6163
if evaluation_accepted:
62-
return JSONResponse(content={"message": "Your evaluation started successfully.", "evaluation_id": evaluation_id}, status_code=HTTP_202_ACCEPTED)
64+
return JSONResponse(
65+
content={
66+
"message": "Your evaluation started successfully.",
67+
"evaluation_id": evaluation_id,
68+
"experiment_id": experiment_id,
69+
"run_id": run_id,
70+
}, status_code=HTTP_202_ACCEPTED
71+
)
6372
else:
64-
return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
73+
return JSONResponse(
74+
content={
75+
"message": "Another training or evaluation on this model is still active. Please retry later.",
76+
"experiment_id": experiment_id,
77+
"run_id": run_id,
78+
}, status_code=HTTP_503_SERVICE_UNAVAILABLE
79+
)
6580

6681

6782
@router.post("/sanity-check",

app/api/routers/metacat_training.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uuid
33
import json
44
import logging
5-
from typing import List, Union
5+
from typing import List, Tuple, Union
66
from typing_extensions import Annotated
77

88
from fastapi import APIRouter, Depends, UploadFile, Query, Request, File
@@ -53,7 +53,7 @@ async def train_metacat(request: Request,
5353
data_file.seek(0)
5454
training_id = tracking_id or str(uuid.uuid4())
5555
try:
56-
training_accepted = model_service.train_metacat(data_file,
56+
training_response = model_service.train_metacat(data_file,
5757
epochs,
5858
log_frequency,
5959
training_id,
@@ -65,13 +65,27 @@ async def train_metacat(request: Request,
6565
for file in files:
6666
file.close()
6767

68-
return _get_training_response(training_accepted, training_id)
68+
return _get_training_response(training_response, training_id)
6969

7070

71-
def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse:
71+
def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse:
72+
training_accepted, experiment_id, run_id = training_response
7273
if training_accepted:
7374
logger.debug("Training accepted with ID: %s", training_id)
74-
return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED)
75+
return JSONResponse(
76+
content={
77+
"message": "Your training started successfully.",
78+
"training_id": training_id,
79+
"experiment_id": experiment_id,
80+
"run_id": run_id,
81+
}, status_code=HTTP_202_ACCEPTED
82+
)
7583
else:
7684
logger.debug("Training refused due to another active training or evaluation on this model")
77-
return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
85+
return JSONResponse(
86+
content={
87+
"message": "Another training or evaluation on this model is still active. Please retry your training later.",
88+
"experiment_id": experiment_id,
89+
"run_id": run_id,
90+
}, status_code=HTTP_503_SERVICE_UNAVAILABLE
91+
)

app/api/routers/supervised_training.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uuid
33
import json
44
import logging
5-
from typing import List, Union
5+
from typing import List, Tuple, Union
66
from typing_extensions import Annotated
77

88
from fastapi import APIRouter, Depends, UploadFile, Query, Request, File, Form
@@ -55,7 +55,7 @@ async def train_supervised(request: Request,
5555
data_file.seek(0)
5656
training_id = tracking_id or str(uuid.uuid4())
5757
try:
58-
training_accepted = model_service.train_supervised(data_file,
58+
training_response = model_service.train_supervised(data_file,
5959
epochs,
6060
log_frequency,
6161
training_id,
@@ -69,13 +69,27 @@ async def train_supervised(request: Request,
6969
for file in files:
7070
file.close()
7171

72-
return _get_training_response(training_accepted, training_id)
72+
return _get_training_response(training_response, training_id)
7373

7474

75-
def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse:
75+
def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse:
76+
training_accepted, experiment_id, run_id = training_response
7677
if training_accepted:
7778
logger.debug("Training accepted with ID: %s", training_id)
78-
return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED)
79+
return JSONResponse(
80+
content={
81+
"message": "Your training started successfully.",
82+
"training_id": training_id,
83+
"experiment_id": experiment_id,
84+
"run_id": run_id,
85+
}, status_code=HTTP_202_ACCEPTED
86+
)
7987
else:
8088
logger.debug("Training refused due to another active training or evaluation on this model")
81-
return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
89+
return JSONResponse(
90+
content={
91+
"message": "Another training or evaluation on this model is still active. Please retry your training later.",
92+
"experiment_id": experiment_id,
93+
"run_id": run_id,
94+
}, status_code=HTTP_503_SERVICE_UNAVAILABLE
95+
)

app/api/routers/unsupervised_training.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import datasets
77
import zipfile
8-
from typing import List, Union
8+
from typing import List, Tuple, Union
99
from typing_extensions import Annotated
1010

1111
from fastapi import APIRouter, Depends, UploadFile, Query, Request, File
@@ -65,7 +65,7 @@ async def train_unsupervised(request: Request,
6565
data_file.seek(0)
6666
training_id = tracking_id or str(uuid.uuid4())
6767
try:
68-
training_accepted = model_service.train_unsupervised(data_file,
68+
training_response = model_service.train_unsupervised(data_file,
6969
epochs,
7070
log_frequency,
7171
training_id,
@@ -79,7 +79,7 @@ async def train_unsupervised(request: Request,
7979
for file in files:
8080
file.close()
8181

82-
return _get_training_response(training_accepted, training_id)
82+
return _get_training_response(training_response, training_id)
8383

8484

8585
@router.post("/train_unsupervised_with_hf_hub_dataset",
@@ -133,7 +133,7 @@ async def train_unsupervised_with_hf_dataset(request: Request,
133133
hf_dataset.save_to_disk(data_dir.name)
134134

135135
training_id = tracking_id or str(uuid.uuid4())
136-
training_accepted = model_service.train_unsupervised(data_dir,
136+
training_response = model_service.train_unsupervised(data_dir,
137137
epochs,
138138
log_frequency,
139139
training_id,
@@ -143,13 +143,27 @@ async def train_unsupervised_with_hf_dataset(request: Request,
143143
lr_override=lr_override,
144144
test_size=test_size,
145145
description=description)
146-
return _get_training_response(training_accepted, training_id)
146+
return _get_training_response(training_response, training_id)
147147

148148

149-
def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse:
149+
def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse:
150+
training_accepted, experiment_id, run_id = training_response
150151
if training_accepted:
151152
logger.debug("Training accepted with ID: %s", training_id)
152-
return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED)
153+
return JSONResponse(
154+
content={
155+
"message": "Your training started successfully.",
156+
"training_id": training_id,
157+
"experiment_id": experiment_id,
158+
"run_id": run_id,
159+
}, status_code=HTTP_202_ACCEPTED
160+
)
153161
else:
154162
logger.debug("Training refused due to another active training or evaluation on this model")
155-
return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
163+
return JSONResponse(
164+
content={
165+
"message": "Another training or evaluation on this model is still active. Please retry later.",
166+
"experiment_id": experiment_id,
167+
"run_id": run_id,
168+
}, status_code=HTTP_503_SERVICE_UNAVAILABLE
169+
)

app/model_services/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def batch_annotate(self, texts: List[str]) -> List[List[Dict[str, Any]]]:
5656
def init_model(self) -> None:
5757
raise NotImplementedError
5858

59-
def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool:
59+
def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]:
6060
raise NotImplementedError
6161

62-
def train_unsupervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool:
62+
def train_unsupervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]:
6363
raise NotImplementedError
6464

65-
def train_metacat(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool:
65+
def train_metacat(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]:
6666
raise NotImplementedError

app/model_services/huggingface_ner_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def train_supervised(self,
156156
raw_data_files: Optional[List[TextIO]] = None,
157157
description: Optional[str] = None,
158158
synchronised: bool = False,
159-
**hyperparams: Dict[str, Any]) -> bool:
159+
**hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
160160
if self._supervised_trainer is None:
161161
raise ConfigurationException("The supervised trainer is not enabled")
162162
return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
@@ -170,7 +170,7 @@ def train_unsupervised(self,
170170
raw_data_files: Optional[List[TextIO]] = None,
171171
description: Optional[str] = None,
172172
synchronised: bool = False,
173-
**hyperparams: Dict[str, Any]) -> bool:
173+
**hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
174174
if self._unsupervised_trainer is None:
175175
raise ConfigurationException("The unsupervised trainer is not enabled")
176176
return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)

app/model_services/medcat_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def train_supervised(self,
121121
raw_data_files: Optional[List[TextIO]] = None,
122122
description: Optional[str] = None,
123123
synchronised: bool = False,
124-
**hyperparams: Dict[str, Any]) -> bool:
124+
**hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
125125
if self._supervised_trainer is None:
126126
raise ConfigurationException("The supervised trainer is not enabled")
127127
return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
@@ -135,7 +135,7 @@ def train_unsupervised(self,
135135
raw_data_files: Optional[List[TextIO]] = None,
136136
description: Optional[str] = None,
137137
synchronised: bool = False,
138-
**hyperparams: Dict[str, Any]) -> bool:
138+
**hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
139139
if self._unsupervised_trainer is None:
140140
raise ConfigurationException("The unsupervised trainer is not enabled")
141141
return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
@@ -149,7 +149,7 @@ def train_metacat(self,
149149
raw_data_files: Optional[List[TextIO]] = None,
150150
description: Optional[str] = None,
151151
synchronised: bool = False,
152-
**hyperparams: Dict[str, Any]) -> bool:
152+
**hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
153153
if self._metacat_trainer is None:
154154
raise ConfigurationException("The metacat trainer is not enabled")
155155
return self._metacat_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)

app/model_services/medcat_model_deid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import inspect
33
import threading
44
import torch
5-
from typing import Dict, List, TextIO, Optional, Any, final, Callable
5+
from typing import Dict, List, TextIO, Tuple, Optional, Any, final, Callable
66
from functools import partial
77
from transformers import pipeline
88
from medcat.cat import CAT
@@ -147,7 +147,7 @@ def train_supervised(self,
147147
raw_data_files: Optional[List[TextIO]] = None,
148148
description: Optional[str] = None,
149149
synchronised: bool = False,
150-
**hyperparams: Dict[str, Any]) -> bool:
150+
**hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
151151
if self._supervised_trainer is None:
152152
raise ConfigurationException("Trainers are not enabled")
153153
return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)

app/trainers/base.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from abc import ABC, abstractmethod
1010
from concurrent.futures import ThreadPoolExecutor
1111
from functools import partial
12-
from typing import TextIO, Callable, Dict, Optional, Any, List, Union, final
12+
from typing import TextIO, Callable, Dict, Tuple, Optional, Any, List, Union, final
1313
from config import Settings
1414
from management.tracker_client import TrackerClient
1515
from data import doc_dataset, anno_dataset
@@ -26,6 +26,8 @@ def __init__(self, config: Settings, model_name: str) -> None:
2626
self._model_name = model_name
2727
self._training_lock = threading.Lock()
2828
self._training_in_progress = False
29+
self._experiment_id = None
30+
self._run_id = None
2931
self._tracker_client = TrackerClient(self._config.MLFLOW_TRACKING_URI)
3032
self._executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(max_workers=1)
3133

@@ -37,6 +39,14 @@ def model_name(self) -> str:
3739
def model_name(self, model_name: str) -> None:
3840
self._model_name = model_name
3941

42+
@property
43+
def experiment_id(self) -> str:
44+
return self._experiment_id or ""
45+
46+
@property
47+
def run_id(self) -> str:
48+
return self._run_id or ""
49+
4050
@final
4151
def start_training(self,
4252
run: Callable,
@@ -48,13 +58,13 @@ def start_training(self,
4858
input_file_name: str,
4959
raw_data_files: Optional[List[TextIO]] = None,
5060
description: Optional[str] = None,
51-
synchronised: bool = False) -> bool:
61+
synchronised: bool = False) -> Tuple[bool, str, str]:
5262
with self._training_lock:
5363
if self._training_in_progress:
54-
return False
64+
return False, self.experiment_id, self.run_id
5565
else:
5666
loop = asyncio.get_event_loop()
57-
experiment_id, run_id = self._tracker_client.start_tracking(
67+
self._experiment_id, self._run_id = self._tracker_client.start_tracking(
5868
model_name=self._model_name,
5969
input_file_name=input_file_name,
6070
base_model_original=self._config.BASE_MODEL_FULL_PATH,
@@ -101,15 +111,15 @@ def start_training(self,
101111
else:
102112
raise ValueError(f"Unknown training type: {training_type}")
103113

104-
logger.info("Starting training job: %s with experiment ID: %s", training_id, experiment_id)
114+
logger.info("Starting training job: %s with experiment ID: %s", training_id, self.experiment_id)
105115
self._training_in_progress = True
106116
training_task = asyncio.ensure_future(loop.run_in_executor(self._executor,
107-
partial(run, self, training_params, data_file, log_frequency, run_id, description)))
117+
partial(run, self, training_params, data_file, log_frequency, self.run_id, description)))
108118

109119
if synchronised:
110120
loop.run_until_complete(training_task)
111121

112-
return True
122+
return True, self.experiment_id, self.run_id
113123

114124
@staticmethod
115125
def _make_model_file_copy(model_file_path: str, run_id: str) -> str:
@@ -161,7 +171,7 @@ def train(self,
161171
raw_data_files: Optional[List[TextIO]] = None,
162172
description: Optional[str] = None,
163173
synchronised: bool = False,
164-
**hyperparams: Dict[str, Any]) -> bool:
174+
**hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
165175
training_type = TrainingType.SUPERVISED.value
166176
training_params = {
167177
"data_path": data_file.name,
@@ -204,7 +214,7 @@ def train(self,
204214
raw_data_files: Optional[List[TextIO]] = None,
205215
description: Optional[str] = None,
206216
synchronised: bool = False,
207-
**hyperparams: Dict[str, Any]) -> bool:
217+
**hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
208218
training_type = TrainingType.UNSUPERVISED.value
209219
training_params = {
210220
"nepochs": epochs,

0 commit comments

Comments
 (0)