Skip to content

Commit 2635291

Browse files
authored
Create training dataset (#4998)
1 parent 2c6bd6d commit 2635291

File tree

27 files changed

+630
-305
lines changed

27 files changed

+630
-305
lines changed

application/backend/app/api/dependencies.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,17 @@ def get_project_service(
217217
) -> ProjectService:
218218
"""Provides a ProjectService instance for managing projects."""
219219
return ProjectService(
220-
data_dir=data_dir, db_session=db, label_service=label_service, pipeline_service=pipeline_service
220+
data_dir=data_dir, label_service=label_service, pipeline_service=pipeline_service, db_session=db
221221
)
222222

223223

224224
def get_dataset_service(
225-
data_dir: Annotated[Path, Depends(get_data_dir)], db: Annotated[Session, Depends(get_db)]
225+
data_dir: Annotated[Path, Depends(get_data_dir)],
226+
label_service: Annotated[LabelService, Depends(get_label_service)],
227+
db: Annotated[Session, Depends(get_db)],
226228
) -> DatasetService:
227229
"""Provides a DatasetService instance."""
228-
return DatasetService(data_dir=data_dir, db_session=db)
230+
return DatasetService(data_dir=data_dir, label_service=label_service, db_session=db)
229231

230232

231233
def get_project(

application/backend/app/api/routers/datasets.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from app.models import DatasetItemAnnotationStatus, DatasetItemSubset
2121
from app.schemas import ProjectView
2222
from app.services import DatasetService, ResourceNotFoundError
23-
from app.services.dataset_service import AnnotationValidationError, InvalidImageError, SubsetAlreadyAssignedError
23+
from app.services.dataset_service import (
24+
AnnotationValidationError,
25+
DatasetItemFilters,
26+
InvalidImageError,
27+
SubsetAlreadyAssignedError,
28+
)
2429

2530
router = APIRouter(prefix="/api/projects/{project_id}/dataset/items", tags=["Datasets"])
2631

@@ -122,14 +127,16 @@ def list_dataset_items( # noqa: PLR0913
122127
subset=subset,
123128
)
124129
dataset_items = dataset_service.list_dataset_items(
125-
project=project,
126-
limit=limit,
127-
offset=offset,
128-
start_date=start_date,
129-
end_date=end_date,
130-
annotation_status=annotation_status,
131-
label_ids=labels,
132-
subset=subset,
130+
project_id=project.id,
131+
filters=DatasetItemFilters(
132+
limit=limit,
133+
offset=offset,
134+
start_date=start_date,
135+
end_date=end_date,
136+
annotation_status=annotation_status,
137+
label_ids=labels,
138+
subset=subset,
139+
),
133140
)
134141
return DatasetItemsWithPagination(
135142
items=[DatasetItemView.model_validate(dataset_item, from_attributes=True) for dataset_item in dataset_items],
@@ -157,7 +164,7 @@ def get_dataset_item(
157164
) -> DatasetItemView:
158165
"""Get information about a specific dataset item"""
159166
try:
160-
dataset_item = dataset_service.get_dataset_item_by_id(project=project, dataset_item_id=dataset_item_id)
167+
dataset_item = dataset_service.get_dataset_item_by_id(project_id=project.id, dataset_item_id=dataset_item_id)
161168
return DatasetItemView.model_validate(dataset_item, from_attributes=True)
162169
except ResourceNotFoundError as e:
163170
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@@ -179,7 +186,7 @@ def get_dataset_item_binary(
179186
"""Get dataset item binary content"""
180187
try:
181188
binary_path = dataset_service.get_dataset_item_binary_path_by_id(
182-
project=project, dataset_item_id=dataset_item_id
189+
project_id=project.id, dataset_item_id=dataset_item_id
183190
)
184191
return FileResponse(path=binary_path)
185192
except ResourceNotFoundError as e:
@@ -281,7 +288,7 @@ def get_dataset_item_annotations(
281288
) -> DatasetItemAnnotations:
282289
"""Get the dataset item annotations"""
283290
try:
284-
dataset_item = dataset_service.get_dataset_item_by_id(project=project, dataset_item_id=dataset_item_id)
291+
dataset_item = dataset_service.get_dataset_item_by_id(project_id=project.id, dataset_item_id=dataset_item_id)
285292
if dataset_item.annotation_data is None:
286293
raise HTTPException(
287294
status_code=status.HTTP_404_NOT_FOUND, detail="Dataset item has not been annotated yet."
@@ -336,7 +343,7 @@ def assign_dataset_item_subset(
336343
"""Assign dataset item subset"""
337344
try:
338345
dataset_item = dataset_service.assign_dataset_item_subset(
339-
project=project, dataset_item_id=dataset_item_id, subset=subset_config.subset
346+
project_id=project.id, dataset_item_id=dataset_item_id, subset=subset_config.subset
340347
)
341348
return DatasetItemView.model_validate(dataset_item, from_attributes=True)
342349
except ResourceNotFoundError as e:

application/backend/app/api/routers/jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ async def submit_job(
6565
params=TrainingParams(
6666
model_architecture_id=job_request.parameters.model_architecture_id,
6767
parent_model_revision_id=job_request.parameters.parent_model_revision_id,
68-
task_type=project.task.task_type,
68+
task=project.task,
6969
),
7070
)
7171
case _:

application/backend/app/core/logging/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from .config import LogConfig
5+
from .handlers import InterceptHandler
56
from .setup import setup_logging, setup_uvicorn_logging
67
from .utils import logging_ctx
78

89
__all__ = [
10+
"InterceptHandler",
911
"LogConfig",
1012
"logging_ctx",
1113
"setup_logging",

application/backend/app/lifecycle.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from app.db import MigrationManager, get_db_session
2020
from app.scheduler import Scheduler
2121
from app.schemas.job import JobType
22+
from app.services import DatasetService, LabelService
2223
from app.services.base_weights_service import BaseWeightsService
2324
from app.services.data_collect import DataCollector
2425
from app.services.event.event_bus import EventBus
@@ -48,13 +49,16 @@ def setup_job_controller(data_dir: Path, max_parallel_jobs: int) -> tuple[JobQue
4849
base_weights_service = BaseWeightsService(data_dir=data_dir)
4950
subset_service = SubsetService()
5051
subset_assigner = SubsetAssigner()
52+
label_service = LabelService()
53+
dataset_service = DatasetService(data_dir=data_dir, label_service=label_service)
5154
job_runnable_factory.register(
5255
JobType.TRAIN,
5356
partial(
5457
OTXTrainer,
5558
base_weights_service=base_weights_service,
5659
subset_service=subset_service,
5760
subset_assigner=subset_assigner,
61+
dataset_service=dataset_service,
5862
data_dir=data_dir,
5963
db_session_factory=get_db_session,
6064
),

application/backend/app/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# - docker compose up
99

1010
import importlib
11+
import logging
1112
import pkgutil
1213
from pathlib import Path
1314

@@ -18,10 +19,12 @@
1819
from loguru import logger
1920

2021
from app.api import routers
22+
from app.core.logging import InterceptHandler
2123
from app.lifecycle import lifespan
2224
from app.settings import get_settings
2325

2426
settings = get_settings()
27+
logging.basicConfig(handlers=[InterceptHandler()], level=settings.log_level, force=True)
2528
app = FastAPI(
2629
title=settings.app_name,
2730
version=settings.version,

application/backend/app/services/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33

44
from .active_model_service import ActiveModelService
55
from .base import (
6+
BaseSessionManagedService,
67
ResourceInUseError,
78
ResourceNotFoundError,
89
ResourceType,
910
ResourceWithIdAlreadyExistsError,
1011
ResourceWithNameAlreadyExistsError,
1112
)
13+
from .base_weights_service import BaseWeightsService
1214
from .dataset_service import DatasetService
1315
from .dispatch_service import DispatchService
1416
from .label_service import LabelService
@@ -24,6 +26,8 @@
2426

2527
__all__ = [
2628
"ActiveModelService",
29+
"BaseSessionManagedService",
30+
"BaseWeightsService",
2731
"DatasetService",
2832
"DispatchService",
2933
"LabelService",

application/backend/app/services/base.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# Copyright (C) 2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from abc import ABC
5+
from collections.abc import Callable
46
from enum import StrEnum
57

8+
from sqlalchemy.orm import Session
9+
610

711
class ResourceType(StrEnum):
812
"""Enumeration for resource types."""
@@ -55,3 +59,65 @@ class ResourceWithIdAlreadyExistsError(ResourceError):
5559
def __init__(self, resource_type: ResourceType, resource_id: str, message: str | None = None):
5660
msg = message or f"{resource_type} with ID '{resource_id}' already exists."
5761
super().__init__(resource_type, resource_id, msg)
62+
63+
64+
class BaseSessionManagedService(ABC):
    """
    Base class for services that require a managed database session.

    This class supports deferred database session initialization, allowing services
    to be instantiated without an immediate database connection. The session can be
    provided either at construction time or injected later via `set_db_session()`.

    This pattern is useful in scenarios where:
    - Services need to be created before database context is available
    - Database session management is handled externally (e.g., via session factories)
    - Services are used in contexts with different session lifecycle requirements

    A service may also register child services via `register_managed_services()`.
    Children receive every session later passed to `set_db_session()`; a child that
    has no directly-set session additionally receives the parent's current session
    at registration time, so the result does not depend on call order.

    Args:
        db_session: Optional database session to use immediately. If not provided,
            the session must be set later via `set_db_session()` or a factory must be provided.
        db_session_factory: Optional callable that returns a database session when invoked.
            Used as a fallback if no session is directly provided.

    Raises:
        RuntimeError: When accessing `db_session` property without a session or factory configured.

    Example:
        >>> # With immediate session
        >>> service = MyService(db_session=session)
        >>>
        >>> # With deferred session
        >>> service = MyService()
        >>> service.set_db_session(session)
        >>>
        >>> # With session factory
        >>> service = MyService(db_session_factory=lambda: get_session())
    """

    def __init__(
        self,
        db_session: Session | None = None,
        db_session_factory: Callable[[], Session] | None = None,
    ):
        self._db_session: Session | None = db_session
        self._db_session_factory = db_session_factory
        # Child services that must share whatever session this service is given.
        self._session_managed_services: list[BaseSessionManagedService] = []

    def set_db_session(self, db_session: Session) -> None:
        """
        Set the database session for this service and all registered child services.

        Args:
            db_session: Session to use from now on; it replaces any session the
                service or its registered children currently hold.
        """
        self._db_session = db_session
        for service in self._session_managed_services:
            service.set_db_session(db_session)

    def register_managed_services(self, *services: "BaseSessionManagedService") -> None:
        """
        Register one or more child services that also require session management.

        Registered services receive every session subsequently passed to
        `set_db_session()`. If this service already holds a session, it is handed
        to each newly registered service that has no directly-set session of its
        own, so registering after the session was provided behaves the same as
        registering before it.

        Args:
            *services: Child services whose session should follow this service's session.
        """
        self._session_managed_services.extend(services)
        if self._db_session is None:
            return
        for service in services:
            # Never clobber a session the child was explicitly given; only fill the gap.
            if service._db_session is None:
                service.set_db_session(self._db_session)

    @property
    def db_session(self) -> Session:
        """
        Return the session to use for database access.

        A directly provided session takes precedence. Otherwise the session factory
        is invoked; note that the factory is called on every access and its result
        is not stored on the service.

        Raises:
            RuntimeError: If neither a session nor a session factory is configured.
        """
        if self._db_session is not None:
            return self._db_session
        if self._db_session_factory is not None:
            return self._db_session_factory()
        raise RuntimeError("No DB session available. Provide session or session factory.")

application/backend/app/services/data_collect/data_collector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,11 @@ def collect(
180180
return
181181
frame_data = cv2.cvtColor(frame_data, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
182182
with get_db_session() as session:
183-
labels = LabelService(db_session=session).list_all(project_id=project.id)
183+
label_service = LabelService(db_session=session)
184+
labels = label_service.list_all(project_id=project.id)
184185
annotations = convert_prediction(labels=labels, frame_data=frame_data, prediction=inference_data.prediction)
185186

186-
dataset_service = DatasetService(data_dir=self.data_dir, db_session=session)
187+
dataset_service = DatasetService(data_dir=self.data_dir, label_service=label_service, db_session=session)
187188
dataset_service.create_dataset_item(
188189
project=project,
189190
name=f"{timestamp:.4f}".replace(".", "_"),

0 commit comments

Comments
 (0)