Commit fcbaf56

Save initial model revision in a training job (#5024)
1 parent 4dd5e77 commit fcbaf56

7 files changed: +211 lines, -129 lines

application/backend/app/lifecycle.py

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,7 @@
 from app.db import MigrationManager, get_db_session
 from app.scheduler import Scheduler
 from app.schemas.job import JobType
-from app.services import DatasetService, LabelService
+from app.services import DatasetService, LabelService, ModelService
 from app.services.base_weights_service import BaseWeightsService
 from app.services.data_collect import DataCollector
 from app.services.event.event_bus import EventBus
@@ -50,6 +50,7 @@ def setup_job_controller(data_dir: Path, max_parallel_jobs: int) -> tuple[JobQue
     subset_service = SubsetService()
     subset_assigner = SubsetAssigner()
     label_service = LabelService()
+    model_service = ModelService()
     dataset_service = DatasetService(data_dir=data_dir, label_service=label_service)
     job_runnable_factory.register(
         JobType.TRAIN,
@@ -59,6 +60,7 @@ def setup_job_controller(data_dir: Path, max_parallel_jobs: int) -> tuple[JobQue
            subset_service=subset_service,
            subset_assigner=subset_assigner,
            dataset_service=dataset_service,
+           model_service=model_service,
            data_dir=data_dir,
            db_session_factory=get_db_session,
        ),

application/backend/app/services/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,7 @@
 from .dispatch_service import DispatchService
 from .label_service import LabelService
 from .metrics_service import MetricsService
-from .model_service import ModelService
+from .model_service import ModelRevisionMetadata, ModelService
 from .pipeline_metrics_service import PipelineMetricsService
 from .pipeline_service import PipelineService
 from .project_service import ProjectService
@@ -32,6 +32,7 @@
     "DispatchService",
     "LabelService",
     "MetricsService",
+    "ModelRevisionMetadata",
     "ModelService",
     "PipelineMetricsService",
     "PipelineService",

application/backend/app/services/dataset_service.py

Lines changed: 3 additions & 2 deletions
@@ -375,7 +375,7 @@ def _get_image_path(item: DatasetItem) -> str:
            get_image_path=_get_image_path,
        )

-    def save_revision(self, project_id: UUID, dataset: dm.Dataset) -> None:
+    def save_revision(self, project_id: UUID, dataset: dm.Dataset) -> UUID:
        """
        Saves the dataset as a new revision.

@@ -387,7 +387,7 @@ def save_revision(self, project_id: UUID, dataset: dm.Dataset) -> None:
            dataset: The Datumaro dataset to export.

        Returns:
-            None
+            UUID: The UUID of the newly created dataset revision.
        """
        revision_repo = DatasetRevisionRepository(db=self.db_session)
        revision_db = revision_repo.save(
@@ -403,3 +403,4 @@ def save_revision(self, project_id: UUID, dataset: dm.Dataset) -> None:
            export_images=True,
            as_zip=True,
        )
+        return UUID(revision_db.id)
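
Since save_revision now returns the id of the stored revision, a caller can hand that UUID straight to downstream consumers instead of re-querying DatasetRevisionDB. A minimal sketch of the calling pattern, mirroring how the trainer below drives DatasetService; it assumes get_db_session from app.db can be used directly as the session factory (as lifecycle.py wires it), and the helper name is illustrative, not part of this commit:

from uuid import UUID

from app.db import get_db_session
from app.models import DatasetItemAnnotationStatus
from app.schemas.project import TaskBase
from app.services import DatasetService


def snapshot_reviewed_dataset(dataset_service: DatasetService, project_id: UUID, task: TaskBase) -> UUID:
    """Export the project's reviewed items as a new dataset revision and return its id."""
    with get_db_session() as db:
        dataset_service.set_db_session(db)
        dm_dataset = dataset_service.get_dm_dataset(project_id, task, DatasetItemAnnotationStatus.REVIEWED)
        # save_revision now returns the UUID of the stored revision instead of None
        return dataset_service.save_revision(project_id=project_id, dataset=dm_dataset)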

application/backend/app/services/model_service.py

Lines changed: 54 additions & 11 deletions
@@ -1,24 +1,35 @@
 # Copyright (C) 2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+from dataclasses import dataclass
 from uuid import UUID

 from sqlalchemy.exc import IntegrityError
-from sqlalchemy.orm import Session

-from app.repositories import ModelRevisionRepository, ProjectRepository
+from app.db.schema import ModelRevisionDB
+from app.models.training_configuration.configuration import TrainingConfiguration
+from app.repositories import LabelRepository, ModelRevisionRepository, ProjectRepository
 from app.schemas.model import Model as ModelSchema
+from app.schemas.model import TrainingStatus

-from .base import ResourceInUseError, ResourceNotFoundError, ResourceType
+from .base import BaseSessionManagedService, ResourceInUseError, ResourceNotFoundError, ResourceType
 from .mappers.model_revision_mapper import ModelRevisionMapper
 from .parent_process_guard import parent_process_only


-class ModelService:
-    """Service to register and activate models"""
+@dataclass(frozen=True)
+class ModelRevisionMetadata:
+    model_id: UUID
+    project_id: UUID
+    architecture_id: str
+    parent_revision_id: UUID | None
+    dataset_revision_id: UUID | None
+    training_status: TrainingStatus
+    training_configuration: TrainingConfiguration | None = None
+

-    def __init__(self, db_session: Session) -> None:
-        self._db_session = db_session
+class ModelService(BaseSessionManagedService):
+    """Service to register and activate models"""

     def get_model_by_id(self, project_id: UUID, model_id: UUID) -> ModelSchema:
         """
@@ -39,7 +50,7 @@ def get_model_by_id(self, project_id: UUID, model_id: UUID) -> ModelSchema:
            ResourceNotFoundError: If the project with the given project_id does not exist,
                or if no model with the given model_id is found within the project.
        """
-        project_repo = ProjectRepository(self._db_session)
+        project_repo = ProjectRepository(self.db_session)
        # Prefer using a JOIN here since the list of model revisions per project is not large,
        # and it allows us to check for project existence and fetch the model in a single query.
        project = project_repo.get_by_id(str(project_id))
@@ -72,10 +83,10 @@ def delete_model_by_id(self, project_id: UUID, model_id: UUID) -> None:
            ResourceInUseError: If the model cannot be deleted due to integrity constraints
                (e.g., the model is referenced by other entities).
        """
-        project_repo = ProjectRepository(self._db_session)
+        project_repo = ProjectRepository(self.db_session)
        if not project_repo.exists(str(project_id)):
            raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))
-        model_rev_repo = ModelRevisionRepository(self._db_session)
+        model_rev_repo = ModelRevisionRepository(self.db_session)
        try:
            # TODO: delete model artifacts from filesystem when implemented
            deleted = model_rev_repo.delete(str(model_id))
@@ -102,8 +113,40 @@ def list_models(self, project_id: UUID) -> list[ModelSchema]:
        Raises:
            ResourceNotFoundError: If the project with the given project_id does not exist.
        """
-        project_repo = ProjectRepository(self._db_session)
+        project_repo = ProjectRepository(self.db_session)
        project = project_repo.get_by_id(str(project_id))
        if not project:
            raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))
        return [ModelRevisionMapper.to_schema(model_rev_db) for model_rev_db in project.model_revisions]
+
+    def create_revision(self, metadata: ModelRevisionMetadata) -> None:
+        """
+        Create and persist a new model revision for the given project metadata.
+
+        Reads the project's label definitions, serializes them into a dict format,
+        combines them with the provided metadata into a new model revision record,
+        and saves it to the database.
+
+        Args:
+            metadata (ModelRevisionMetadata): Metadata used to create the new model revision
+                including project id, architecture, optional parent revision id,
+                dataset revision id, training status and optional training
+                configuration.
+        """
+        label_repo = LabelRepository(project_id=str(metadata.project_id), db=self.db_session)
+        labels_schema_rev = {"labels": [{"name": label.name, "id": label.id} for label in label_repo.list_all()]}
+        model_revision_repo = ModelRevisionRepository(self.db_session)
+        model_revision_repo.save(
+            ModelRevisionDB(
+                id=str(metadata.model_id),
+                project_id=str(metadata.project_id),
+                architecture=metadata.architecture_id,
+                parent_revision=str(metadata.parent_revision_id) if metadata.parent_revision_id else None,
+                training_status=metadata.training_status,
+                training_configuration=metadata.training_configuration.model_dump()
+                if metadata.training_configuration
+                else {},
+                training_dataset_id=str(metadata.dataset_revision_id),
+                label_schema_revision=labels_schema_rev,
+            )
+        )
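
For reference, a hedged sketch of how the new create_revision API could be exercised outside the trainer. It assumes get_db_session from app.db yields a session context (as in lifecycle.py), that the project already exists with labels defined, and that the placeholder ids and architecture string below are purely illustrative:

from uuid import uuid4

from app.db import get_db_session
from app.schemas.model import TrainingStatus
from app.services import ModelRevisionMetadata, ModelService

model_service = ModelService()  # session-managed: no session is bound at construction time

with get_db_session() as db:
    model_service.set_db_session(db)
    model_service.create_revision(
        ModelRevisionMetadata(
            model_id=uuid4(),                         # in the training job this comes from TrainingParams
            project_id=uuid4(),                       # must reference an existing project with labels
            architecture_id="example_architecture",   # illustrative placeholder, not a real manifest id
            parent_revision_id=None,                  # None means training starts from base weights
            dataset_revision_id=uuid4(),              # id returned by DatasetService.save_revision
            training_status=TrainingStatus.NOT_STARTED,
        )
    )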

application/backend/app/services/training/otx_trainer.py

Lines changed: 60 additions & 41 deletions
@@ -4,6 +4,7 @@
 import time
 from collections.abc import Callable
 from contextlib import AbstractContextManager
+from dataclasses import dataclass
 from pathlib import Path
 from uuid import UUID

@@ -14,14 +15,25 @@

 from app.core.run import ExecutionContext
 from app.models import DatasetItemAnnotationStatus
-from app.services import BaseWeightsService, DatasetService
+from app.schemas.model import TrainingStatus
+from app.schemas.project import TaskBase
+from app.services import BaseWeightsService, DatasetService, ModelRevisionMetadata, ModelService

 from .base import Trainer, step
+from .models import TrainingParams
 from .subset_assignment import SplitRatios, SubsetAssigner, SubsetService

 MODEL_WEIGHTS_PATH = "model_weights_path"


+@dataclass(frozen=True)
+class DatasetInfo:
+    training: Dataset
+    validation: Dataset
+    testing: Dataset
+    revision_id: UUID
+
+
 class OTXTrainer(Trainer):
     """OTX-specific trainer implementation."""

@@ -31,6 +43,7 @@ def __init__(
        base_weights_service: BaseWeightsService,
        subset_service: SubsetService,
        dataset_service: DatasetService,
+        model_service: ModelService,
        subset_assigner: SubsetAssigner,
        db_session_factory: Callable[[], AbstractContextManager[Session]],
    ):
@@ -39,26 +52,22 @@
        self._base_weights_service = base_weights_service
        self._subset_service = subset_service
        self._dataset_service = dataset_service
+        self._model_service = model_service
        self._subset_assigner = subset_assigner
        self._db_session_factory = db_session_factory
-        self._training_dataset: Dataset | None = None
-        self._validation_dataset: Dataset | None = None
-        self._testing_dataset: Dataset | None = None

    @step("Prepare Model Weights")
-    def prepare_weights(self) -> Path:
+    def prepare_weights(self, training_params: TrainingParams) -> Path:
        """
        Prepare weights for training based on training parameters.

        If a parent model revision ID is provided, it fetches the weights from the parent model.
        Otherwise, it retrieves the base weights for the specified model architecture.
        """
-        if self._training_params is None:
-            raise ValueError("Training parameters not set")
-        parent_model_revision_id = self._training_params.parent_model_revision_id
-        task = self._training_params.task
-        model_architecture_id = self._training_params.model_architecture_id
-        project_id = self._training_params.project_id
+        parent_model_revision_id = training_params.parent_model_revision_id
+        task = training_params.task
+        model_architecture_id = training_params.model_architecture_id
+        project_id = training_params.project_id
        if parent_model_revision_id is None:
            return self._base_weights_service.get_local_weights_path(
                task=task.task_type, model_manifest_id=model_architecture_id
@@ -74,17 +83,11 @@ def prepare_weights(self) -> Path:
        return weights_path

    @step("Assign Dataset Subsets")
-    def assign_subsets(self) -> None:
+    def assign_subsets(self, project_id: UUID) -> None:
        """Assigning subsets to all unassigned dataset items in the project dataset."""
-        if self._training_params is None:
-            raise ValueError("Training parameters not set")
-        project_id = self._training_params.project_id
-        self.report_progress("Retrieving unassigned items")
-        if project_id is None:
-            raise ValueError("Project ID must be provided for subset assignment")
-
        with self._db_session_factory() as db:
            self._subset_service.set_db_session(db)
+            self.report_progress("Retrieving unassigned items")
            unassigned_items = self._subset_service.get_unassigned_items_with_labels(project_id)

            if not unassigned_items:
@@ -112,30 +115,41 @@ def assign_subsets(self) -> None:
            self.report_progress(f"Successfully assigned {len(assignments)} items to subsets")

    @step("Create Training Dataset")
-    def create_training_dataset(self) -> None:
+    def create_training_dataset(self, project_id: UUID, task: TaskBase) -> DatasetInfo:
        """Create datasets for training, validation, and testing."""
-        if self._training_params is None:
-            raise ValueError("Training parameters not set")
-        project_id = self._training_params.project_id
-        if project_id is None:
-            raise ValueError("Project ID must be provided")
-        task = self._training_params.task
-
        with self._db_session_factory() as db:
            self._dataset_service.set_db_session(db)
            dm_dataset = self._dataset_service.get_dm_dataset(project_id, task, DatasetItemAnnotationStatus.REVIEWED)
-            self._training_dataset = dm_dataset.filter_by_subset(Subset.TRAINING)
-            self._validation_dataset = dm_dataset.filter_by_subset(Subset.VALIDATION)
-            self._testing_dataset = dm_dataset.filter_by_subset(Subset.TESTING)
-            self._dataset_service.save_revision(project_id, dm_dataset)
+            return DatasetInfo(
+                training=dm_dataset.filter_by_subset(Subset.TRAINING),
+                validation=dm_dataset.filter_by_subset(Subset.VALIDATION),
+                testing=dm_dataset.filter_by_subset(Subset.TESTING),
+                revision_id=self._dataset_service.save_revision(project_id, dm_dataset),
+            )
+
+    @step("Prepare Model Metadata")
+    def prepare_model(self, training_params: TrainingParams, dataset_revision_id: UUID) -> None:
+        if training_params.project_id is None:
+            raise ValueError("Project ID must be provided for model preparation")
+        with self._db_session_factory() as db:
+            self._model_service.set_db_session(db)
+            self._model_service.create_revision(
+                ModelRevisionMetadata(
+                    model_id=training_params.model_id,
+                    project_id=training_params.project_id,
+                    architecture_id=training_params.model_architecture_id,
+                    parent_revision_id=training_params.parent_model_revision_id,
+                    training_configuration=None,  # TODO: to be set when config is added
+                    dataset_revision_id=dataset_revision_id,
+                    training_status=TrainingStatus.NOT_STARTED,
+                )
+            )

    @step("Train Model with OTX")
-    def train_model(self) -> None:
+    def train_model(self, training_params: TrainingParams) -> None:
        """Execute OTX model training."""
-        if self._training_params is None:
-            raise ValueError("Training parameters not set")
        # Simulate training with progress reporting
-        job_id = self._training_params.job_id
+        job_id = training_params.job_id
        step_count = 20
        for i in range(step_count):
            time.sleep(1)
@@ -145,12 +159,17 @@ def train_model(self) -> None:

    def run(self, ctx: ExecutionContext) -> None:
        self._ctx = ctx
-        self._training_params = self._get_training_params(ctx)
-
-        self.prepare_weights()
-        self.assign_subsets()
-        self.create_training_dataset()
-        self.train_model()
+        training_params = self._get_training_params(ctx)
+        project_id = training_params.project_id
+        if project_id is None:
+            raise ValueError("Project ID must be provided in training parameters")
+        task = training_params.task
+
+        self.prepare_weights(training_params)
+        self.assign_subsets(project_id)
+        dataset_info = self.create_training_dataset(project_id, task)
+        self.prepare_model(training_params, dataset_info.revision_id)
+        self.train_model(training_params)

    @staticmethod
    def __build_model_weights_path(data_dir: Path, project_id: UUID, model_id: UUID) -> Path:
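
The step refactor makes the data flow between steps explicit: create_training_dataset returns a DatasetInfo whose revision_id is what prepare_model records on the new model revision. A minimal sketch of that hand-off in isolation, assuming a fully constructed OTXTrainer and that the @step-decorated methods can be invoked directly with whatever execution context their progress reporting needs; the helper function below is illustrative, e.g. for a focused test, not part of this commit:

from uuid import UUID

from app.services.training.models import TrainingParams
from app.services.training.otx_trainer import DatasetInfo, OTXTrainer


def build_dataset_and_model_revision(trainer: OTXTrainer, params: TrainingParams, project_id: UUID) -> DatasetInfo:
    # Mirrors the middle of OTXTrainer.run(): the dataset revision id produced here
    # is the one stored on the model revision created with status NOT_STARTED.
    dataset_info = trainer.create_training_dataset(project_id, params.task)
    trainer.prepare_model(params, dataset_info.revision_id)
    return dataset_info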

application/backend/tests/integration/services/test_dataset_service.py

Lines changed: 3 additions & 5 deletions
@@ -1385,13 +1385,11 @@ def test_save_revision(
        project, db_dataset_items = fxt_project_with_subset_items
        dataset = fxt_dataset_service.get_dm_dataset(project.id, project.task, DatasetItemAnnotationStatus.REVIEWED)

-        fxt_dataset_service.save_revision(
+        revision_id = fxt_dataset_service.save_revision(
            project_id=project.id,
            dataset=dataset,
        )

        # Verify that a revision entry was created
-        db_revisions = db_session.query(DatasetRevisionDB).all()
-        assert len(db_revisions) == 1
-        revision_id = db_revisions[0].id
-        assert (fxt_projects_dir / str(project.id) / "dataset_revisions" / revision_id / "dataset.zip").exists()
+        assert db_session.get(DatasetRevisionDB, str(revision_id)) is not None
+        assert (fxt_projects_dir / str(project.id) / "dataset_revisions" / str(revision_id) / "dataset.zip").exists()
