Skip to content

Commit fadcdde

Browse files
itallix and leoll2 authored
Training job: prepare OTX training config (#5028)
Co-authored-by: Leonardo Lai <[email protected]>
1 parent 0db4e67 commit fadcdde

File tree

6 files changed

+145
-57
lines changed

6 files changed

+145
-57
lines changed

application/backend/app/lifecycle.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
from app.db import MigrationManager, get_db_session
2020
from app.scheduler import Scheduler
2121
from app.schemas.job import JobType
22-
from app.services import DatasetService, LabelService, ModelService
22+
from app.services import DatasetService, LabelService, ModelService, TrainingConfigurationService
2323
from app.services.base_weights_service import BaseWeightsService
2424
from app.services.data_collect import DataCollector
2525
from app.services.event.event_bus import EventBus
2626
from app.services.training import OTXTrainer
27+
from app.services.training.otx_trainer import TrainingDependencies
2728
from app.services.training.subset_assignment import SubsetAssigner, SubsetService
2829
from app.settings import get_settings
2930
from app.webrtc.manager import WebRTCManager
@@ -46,23 +47,20 @@ def setup_job_controller(data_dir: Path, max_parallel_jobs: int) -> tuple[JobQue
4647
"""
4748
q = JobQueue()
4849
job_runnable_factory = RunnableFactory[JobType, Runnable]()
49-
base_weights_service = BaseWeightsService(data_dir=data_dir)
50-
subset_service = SubsetService()
51-
subset_assigner = SubsetAssigner()
52-
label_service = LabelService()
53-
model_service = ModelService()
54-
dataset_service = DatasetService(data_dir=data_dir, label_service=label_service)
5550
job_runnable_factory.register(
5651
JobType.TRAIN,
5752
partial(
5853
OTXTrainer,
59-
base_weights_service=base_weights_service,
60-
subset_service=subset_service,
61-
subset_assigner=subset_assigner,
62-
dataset_service=dataset_service,
63-
model_service=model_service,
64-
data_dir=data_dir,
65-
db_session_factory=get_db_session,
54+
training_deps=TrainingDependencies(
55+
base_weights_service=BaseWeightsService(data_dir=data_dir),
56+
subset_service=SubsetService(),
57+
subset_assigner=SubsetAssigner(),
58+
dataset_service=DatasetService(data_dir=data_dir, label_service=LabelService()),
59+
model_service=ModelService(),
60+
training_configuration_service=TrainingConfigurationService(),
61+
data_dir=data_dir,
62+
db_session_factory=get_db_session,
63+
),
6664
),
6765
)
6866
process_runner_factory = ProcessRunnerFactory(job_runnable_factory)

application/backend/app/services/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .sink_service import SinkService
2323
from .source_service import SourceService, SourceUpdateService
2424
from .system_service import SystemService
25+
from .training_configuration_service import TrainingConfigurationService
2526
from .video_stream_service import VideoStreamService
2627

2728
__all__ = [
@@ -46,5 +47,6 @@
4647
"SourceService",
4748
"SourceUpdateService",
4849
"SystemService",
50+
"TrainingConfigurationService",
4951
"VideoStreamService",
5052
]

application/backend/app/services/training/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from uuid import UUID, uuid4
77

88
from loguru import logger
9+
from pydantic import Field
910

1011
from app.core.jobs import Job, JobParams, JobType
1112
from app.schemas.project import TaskBase
@@ -17,7 +18,7 @@ class TrainingParams(JobParams):
1718
model_architecture_id: str
1819
parent_model_revision_id: UUID | None = None
1920
task: TaskBase
20-
model_id: UUID = uuid4() # Reserve the ID for the model to be created for this training job
21+
model_id: UUID = Field(default_factory=uuid4)
2122

2223

2324
class ProjectJob(Job):

application/backend/app/services/training/otx_trainer.py

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,26 @@
88
from pathlib import Path
99
from uuid import UUID
1010

11+
import yaml
1112
from datumaro.experimental import Dataset
1213
from datumaro.experimental.fields import Subset
1314
from loguru import logger
15+
from otx.types.task import OTXTaskType
1416
from sqlalchemy.orm import Session
1517

18+
from app.core.jobs import JobType
1619
from app.core.run import ExecutionContext
17-
from app.models import DatasetItemAnnotationStatus
20+
from app.models import DatasetItemAnnotationStatus, TaskType
21+
from app.models.training_configuration.configuration import TrainingConfiguration
1822
from app.schemas.model import TrainingStatus
1923
from app.schemas.project import TaskBase
20-
from app.services import BaseWeightsService, DatasetService, ModelRevisionMetadata, ModelService
24+
from app.services import (
25+
BaseWeightsService,
26+
DatasetService,
27+
ModelRevisionMetadata,
28+
ModelService,
29+
TrainingConfigurationService,
30+
)
2131

2232
from .base import Trainer, step
2333
from .models import TrainingParams
@@ -26,6 +36,23 @@
2636
MODEL_WEIGHTS_PATH = "model_weights_path"
2737

2838

39+
# TODO: Consider adopting some lightweight DI framework
40+
# As the number of constructor dependencies grows and starts violating ruff rules, we should evaluate DI frameworks like:
41+
# - dependency-injector (https://python-dependency-injector.ets-labs.org/)
42+
# - injector (https://github.com/python-injector/injector)
43+
# - python-inject (https://github.com/ivankorobkov/python-inject)
44+
@dataclass(frozen=True)
45+
class TrainingDependencies:
46+
data_dir: Path
47+
base_weights_service: BaseWeightsService
48+
subset_service: SubsetService
49+
dataset_service: DatasetService
50+
model_service: ModelService
51+
training_configuration_service: TrainingConfigurationService
52+
subset_assigner: SubsetAssigner
53+
db_session_factory: Callable[[], AbstractContextManager[Session]]
54+
55+
2956
@dataclass(frozen=True)
3057
class DatasetInfo:
3158
training: Dataset
@@ -39,22 +66,17 @@ class OTXTrainer(Trainer):
3966

4067
def __init__(
4168
self,
42-
data_dir: Path,
43-
base_weights_service: BaseWeightsService,
44-
subset_service: SubsetService,
45-
dataset_service: DatasetService,
46-
model_service: ModelService,
47-
subset_assigner: SubsetAssigner,
48-
db_session_factory: Callable[[], AbstractContextManager[Session]],
69+
training_deps: TrainingDependencies,
4970
):
5071
super().__init__()
51-
self._data_dir = data_dir
52-
self._base_weights_service = base_weights_service
53-
self._subset_service = subset_service
54-
self._dataset_service = dataset_service
55-
self._model_service = model_service
56-
self._subset_assigner = subset_assigner
57-
self._db_session_factory = db_session_factory
72+
self._data_dir = training_deps.data_dir
73+
self._base_weights_service = training_deps.base_weights_service
74+
self._subset_service = training_deps.subset_service
75+
self._dataset_service = training_deps.dataset_service
76+
self._model_service = training_deps.model_service
77+
self._training_configuration_service = training_deps.training_configuration_service
78+
self._subset_assigner = training_deps.subset_assigner
79+
self._db_session_factory = training_deps.db_session_factory
5880

5981
@step("Prepare Model Weights")
6082
def prepare_weights(self, training_params: TrainingParams) -> Path:
@@ -127,19 +149,28 @@ def create_training_dataset(self, project_id: UUID, task: TaskBase) -> DatasetIn
127149
revision_id=self._dataset_service.save_revision(project_id, dm_dataset),
128150
)
129151

130-
@step("Prepare Model Metadata")
152+
@step("Prepare Model and Training Configuration")
131153
def prepare_model(self, training_params: TrainingParams, dataset_revision_id: UUID) -> None:
132154
if training_params.project_id is None:
133155
raise ValueError("Project ID must be provided for model preparation")
134156
with self._db_session_factory() as db:
157+
self._training_configuration_service.set_db_session(db)
135158
self._model_service.set_db_session(db)
159+
configuration = self._training_configuration_service.get_training_configuration(
160+
project_id=training_params.project_id,
161+
model_architecture_id=training_params.model_architecture_id,
162+
)
163+
config_path = self.__build_model_config_path(
164+
self._data_dir, training_params.project_id, training_params.model_id
165+
)
166+
self.__persist_configuration(configuration, config_path, training_params.task)
136167
self._model_service.create_revision(
137168
ModelRevisionMetadata(
138169
model_id=training_params.model_id,
139170
project_id=training_params.project_id,
140171
architecture_id=training_params.model_architecture_id,
141172
parent_revision_id=training_params.parent_model_revision_id,
142-
training_configuration=None, # TODO: to be set when config is added
173+
training_configuration=configuration,
143174
dataset_revision_id=dataset_revision_id,
144175
training_status=TrainingStatus.NOT_STARTED,
145176
)
@@ -172,5 +203,31 @@ def run(self, ctx: ExecutionContext) -> None:
172203
self.train_model(training_params)
173204

174205
@staticmethod
175-
def __build_model_weights_path(data_dir: Path, project_id: UUID, model_id: UUID) -> Path:
176-
return data_dir / "projects" / str(project_id) / "models" / str(model_id) / "model.pth"
206+
def __base_model_path(data_dir: Path, project_id: UUID, model_id: UUID) -> Path:
207+
return data_dir / "projects" / str(project_id) / "models" / str(model_id)
208+
209+
@classmethod
210+
def __build_model_weights_path(cls, data_dir: Path, project_id: UUID, model_id: UUID) -> Path:
211+
return cls.__base_model_path(data_dir, project_id, model_id) / "model.pth"
212+
213+
@classmethod
214+
def __build_model_config_path(cls, data_dir: Path, project_id: UUID, model_id: UUID) -> Path:
215+
return cls.__base_model_path(data_dir, project_id, model_id) / "config.yaml"
216+
217+
@staticmethod
218+
def __persist_configuration(configuration: TrainingConfiguration, config_path: Path, task: TaskBase) -> None:
219+
extended_config = configuration.model_dump(exclude_none=True)
220+
extended_config["job_type"] = JobType.TRAIN.value
221+
match task.task_type:
222+
case TaskType.CLASSIFICATION:
223+
if task.exclusive_labels:
224+
extended_config["sub_task_type"] = OTXTaskType.MULTI_CLASS_CLS.value
225+
else:
226+
extended_config["sub_task_type"] = OTXTaskType.MULTI_LABEL_CLS.value
227+
case TaskType.DETECTION:
228+
extended_config["sub_task_type"] = OTXTaskType.DETECTION.value
229+
case TaskType.INSTANCE_SEGMENTATION:
230+
extended_config["sub_task_type"] = OTXTaskType.INSTANCE_SEGMENTATION.value
231+
config_path.parent.mkdir(parents=True, exist_ok=True)
232+
with open(config_path, "w") as f:
233+
yaml.dump(extended_config, f, default_flow_style=False)

application/backend/app/services/training_configuration_service.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@
77
from app.models.training_configuration.configuration import PartialTrainingConfiguration, TrainingConfiguration
88
from app.repositories import ModelRevisionRepository, ProjectRepository
99
from app.repositories.training_configuration_repo import TrainingConfigurationRepository
10-
from app.services import ResourceNotFoundError, ResourceType
10+
from app.services import BaseSessionManagedService, ResourceNotFoundError, ResourceType
1111
from app.services.tools import ConfigurationOverlayTools
1212
from app.supported_models import SupportedModels
1313
from app.supported_models.default_models import DefaultModels
1414

1515

16-
class TrainingConfigurationService:
17-
def __init__(self, db_session: Session) -> None:
18-
self._db_session = db_session
19-
self._training_config_repo = TrainingConfigurationRepository(db_session)
16+
class TrainingConfigurationService(BaseSessionManagedService):
17+
def __init__(self, db_session: Session | None = None) -> None:
18+
super().__init__(db_session)
2019

2120
def get_training_configuration(
2221
self,
@@ -63,7 +62,7 @@ def _get_by_model_revision_id(self, project_id: UUID, model_revision_id: UUID) -
6362
Returns:
6463
TrainingConfiguration: The training configuration object.
6564
"""
66-
model = ModelRevisionRepository(str(project_id), self._db_session).get_by_id(str(model_revision_id))
65+
model = ModelRevisionRepository(str(project_id), self.db_session).get_by_id(str(model_revision_id))
6766
if not model:
6867
raise ResourceNotFoundError(ResourceType.MODEL, str(model_revision_id))
6968
return TrainingConfiguration.model_validate(model.training_configuration)
@@ -79,7 +78,7 @@ def _get_by_model_architecture_id(self, project_id: UUID, model_architecture_id:
7978
Returns:
8079
TrainingConfiguration: The training configuration object.
8180
"""
82-
stored_config = TrainingConfigurationRepository(self._db_session).get_by_project_and_model_architecture(
81+
stored_config = TrainingConfigurationRepository(self.db_session).get_by_project_and_model_architecture(
8382
project_id=str(project_id),
8483
model_architecture_id=model_architecture_id,
8584
)
@@ -102,7 +101,7 @@ def _get_default_configuration(self, project_id: UUID) -> TrainingConfiguration:
102101
Returns:
103102
TrainingConfiguration: The default training configuration object.
104103
"""
105-
project = ProjectRepository(self._db_session).get_by_id(str(project_id))
104+
project = ProjectRepository(self.db_session).get_by_id(str(project_id))
106105
if not project:
107106
raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))
108107

@@ -133,7 +132,7 @@ def update_training_configuration(
133132
Returns:
134133
TrainingConfiguration: The updated training configuration object.
135134
"""
136-
project = ProjectRepository(self._db_session).get_by_id(str(project_id))
135+
project = ProjectRepository(self.db_session).get_by_id(str(project_id))
137136
if not project:
138137
raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))
139138

@@ -150,7 +149,8 @@ def update_training_configuration(
150149

151150
validated_updated_config = PartialTrainingConfiguration(**updated_config) # type: ignore[arg-type]
152151

153-
self._training_config_repo.create_or_update(
152+
training_config_repo = TrainingConfigurationRepository(self.db_session)
153+
training_config_repo.create_or_update(
154154
project_id=str(project_id),
155155
model_architecture_id=model_architecture_id,
156156
configuration_data=validated_updated_config.model_dump(),

0 commit comments

Comments
 (0)