Skip to content

Commit 4f3ef49

Browse files
authored
Store training dataset (#5008)
1 parent e6dda5a commit 4f3ef49

File tree

6 files changed

+73
-3
lines changed

6 files changed

+73
-3
lines changed

application/backend/app/repositories/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from .dataset_item_repo import DatasetItemRepository
5+
from .dataset_revision_repo import DatasetRevisionRepository
56
from .label_repo import LabelRepository
67
from .model_revision_repo import ModelRevisionRepository
78
from .pipeline_repo import PipelineRepository
@@ -11,6 +12,7 @@
1112

1213
__all__ = [
1314
"DatasetItemRepository",
15+
"DatasetRevisionRepository",
1416
"LabelRepository",
1517
"ModelRevisionRepository",
1618
"PipelineRepository",
application/backend/app/repositories/dataset_revision_repo.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from sqlalchemy.orm import Session
5+
6+
from app.db.schema import DatasetRevisionDB
7+
from app.repositories.base import BaseRepository
8+
9+
10+
class DatasetRevisionRepository(BaseRepository[DatasetRevisionDB]):
    """Data-access layer for `DatasetRevisionDB` rows.

    This subclass adds no methods of its own; it only binds the generic
    `BaseRepository` to the `DatasetRevisionDB` model.
    """

    def __init__(self, db: Session) -> None:
        """Bind the repository to a SQLAlchemy session.

        Args:
            db: Session forwarded to `BaseRepository` together with the
                `DatasetRevisionDB` model class.
        """
        super().__init__(db, DatasetRevisionDB)

application/backend/app/services/dataset_service.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313

1414
import datumaro.experimental as dm
1515
import numpy as np
16+
from datumaro.experimental.export_import import export_dataset
1617
from loguru import logger
1718
from PIL import Image, UnidentifiedImageError
1819
from sqlalchemy.orm import Session
1920

20-
from app.db.schema import DatasetItemDB
21+
from app.db.schema import DatasetItemDB, DatasetRevisionDB
2122
from app.models import (
2223
DatasetItem,
2324
DatasetItemAnnotation,
@@ -29,7 +30,7 @@
2930
Rectangle,
3031
TaskType,
3132
)
32-
from app.repositories import DatasetItemRepository
33+
from app.repositories import DatasetItemRepository, DatasetRevisionRepository
3334
from app.schemas.project import ProjectBase, ProjectView, TaskBase
3435
from app.services.datumaro_converter import convert_dataset
3536
from app.utils.images import crop_to_thumbnail
@@ -373,3 +374,32 @@ def _get_image_path(item: DatasetItem) -> str:
373374
get_dataset_items=_get_dataset_items,
374375
get_image_path=_get_image_path,
375376
)
377+
378+
def save_revision(self, project_id: UUID, dataset: dm.Dataset) -> None:
    """
    Persist `dataset` as a new dataset revision of a project.

    A `DatasetRevisionDB` row is stored first; the id of the stored row then
    names the directory
    `<projects_dir>/<project_id>/dataset_revisions/<revision id>` to which the
    dataset is exported, images included, as a zip archive.

    Args:
        project_id: UUID of the project that owns the revision.
        dataset: The Datumaro dataset to export.

    Returns:
        None
    """
    repo = DatasetRevisionRepository(db=self.db_session)
    new_revision = DatasetRevisionDB(project_id=str(project_id))
    stored_revision = repo.save(new_revision)
    revision_id = stored_revision.id

    # NOTE(review): the row is saved before the export runs, so a failing
    # export presumably leaves a revision row without an archive on disk;
    # confirm whether cleanup is handled elsewhere.
    project_dir = self.projects_dir / str(project_id)
    export_dir = project_dir / "dataset_revisions" / revision_id
    logger.info("Saving dataset revision '{}' to '{}'", revision_id, export_dir)
    export_dataset(dataset=dataset, output_path=export_dir, export_images=True, as_zip=True)

application/backend/app/services/training/otx_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def create_training_dataset(self) -> None:
127127
self._training_dataset = dm_dataset.filter_by_subset(Subset.TRAINING)
128128
self._validation_dataset = dm_dataset.filter_by_subset(Subset.VALIDATION)
129129
self._testing_dataset = dm_dataset.filter_by_subset(Subset.TESTING)
130+
self._dataset_service.save_revision(project_id, dm_dataset)
130131

131132
@step("Train Model with OTX")
132133
def train_model(self) -> None:

application/backend/tests/integration/services/test_dataset_service.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlalchemy import func, select
1313
from sqlalchemy.orm import Session
1414

15-
from app.db.schema import DatasetItemDB, DatasetItemLabelDB, PipelineDB
15+
from app.db.schema import DatasetItemDB, DatasetItemLabelDB, DatasetRevisionDB, PipelineDB
1616
from app.models import DatasetItemAnnotation, DatasetItemAnnotationStatus, DatasetItemSubset, LabelReference, Rectangle
1717
from app.schemas import PipelineView, ProjectView
1818
from app.services import LabelService, PipelineService, ProjectService
@@ -1373,3 +1373,25 @@ def test_subset_filter_verifies_data_correctness(
13731373
assert len(testing_items) == 1
13741374
for item in testing_items:
13751375
assert item.subset == DatasetItemSubset.TESTING
1376+
1377+
def test_save_revision(
    self,
    fxt_projects_dir: Path,
    fxt_dataset_service: DatasetService,
    fxt_project_with_subset_items: tuple[ProjectView, list[DatasetItemDB]],
    db_session: Session,
) -> None:
    """Saving a revision creates one DB row and a `dataset.zip` on disk."""
    # Only the project is needed here; the dataset items from the fixture are not inspected.
    project, _ = fxt_project_with_subset_items
    reviewed_dataset = fxt_dataset_service.get_dm_dataset(
        project.id, project.task, DatasetItemAnnotationStatus.REVIEWED
    )

    fxt_dataset_service.save_revision(project_id=project.id, dataset=reviewed_dataset)

    # Exactly one revision row must have been persisted.
    stored_revisions = db_session.query(DatasetRevisionDB).all()
    assert len(stored_revisions) == 1

    # The archive is expected inside the directory named after the revision id.
    revisions_root = fxt_projects_dir / str(project.id) / "dataset_revisions"
    assert (revisions_root / stored_revisions[0].id / "dataset.zip").exists()

application/backend/tests/unit/services/training/test_otx_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ def test_create_training_dataset_success(
319319
assert otx_trainer._training_dataset == mock_training_dataset
320320
assert otx_trainer._validation_dataset == mock_validation_dataset
321321
assert otx_trainer._testing_dataset == mock_testing_dataset
322+
fxt_dataset_service.save_revision.assert_called_once_with(project_id, mock_dm_dataset)
322323

323324
def test_create_training_dataset_without_project_id_raises_error(
324325
self,

0 commit comments

Comments
 (0)