
Commit 62e7923

Created DB seeding for Instance Segmentation projects (#5005)
1 parent 8b3b5b8 commit 62e7923

3 files changed: +279, -98 lines changed

application/backend/app/cli.py

Lines changed: 62 additions & 76 deletions
@@ -4,22 +4,19 @@
 """Command line interface for interacting with the Geti Tune application."""

 import sys
-from datetime import datetime, timedelta

 import click

 from app.db import MigrationManager, get_db_session
-from app.db.schema import DatasetItemDB, LabelDB, ModelRevisionDB, PipelineDB, ProjectDB, SinkDB, SourceDB
-from app.models import (
-    DisconnectedSinkConfig,
-    DisconnectedSourceConfig,
-    FixedRateDataCollectionPolicy,
-    OutputFormat,
-    SinkType,
-    SourceType,
-    TaskType,
-    TrainingStatus,
+from app.db.schema import DatasetItemDB, ModelRevisionDB, ProjectDB, SinkDB, SourceDB
+from app.db_seeder import (
+    _create_detection_labels,
+    _create_pipeline_with_video_source,
+    _create_project,
+    _create_segmentation_labels,
+    _create_shared_sinks_sources_folders,
 )
+from app.models import TaskType
 from app.settings import get_settings

 settings = get_settings()
@@ -84,79 +81,68 @@ def check_db() -> None:
 @cli.command()
 @click.option("--with-model", default=False)
 def seed(with_model: bool) -> None:
-    """Seed the database with test data."""
+    """
+    Seed the database with test data.
+
+    Args:
+        with_model (bool): Whether to include pre-trained models in the seed data.
+    """
     # If the app is running, it needs to be restarted since it doesn't track direct DB changes
     # Fixed IDs are used to ensure consistency in tests
     click.echo("Seeding database with test data...")
-    project_id = "9d6af8e8-6017-4ebe-9126-33aae739c5fa"
-    with get_db_session() as db:
-        project = ProjectDB(
-            id=project_id,
-            name="Test Project",
-            task_type=TaskType.DETECTION,
-            exclusive_labels=True,
+    sources, sinks, folders = _create_shared_sinks_sources_folders()
+
+    # Project 1: Object Detection
+    detection_project = _create_project(
+        project_id="9d6af8e8-6017-4ebe-9126-33aae739c5fa",
+        task_type=TaskType.DETECTION,
+        exclusive_labels=True,
+    )
+    detection_labels = _create_detection_labels(project_id=detection_project.id)
+
+    # Project 2: Instance Segmentation
+    segmentation_project = _create_project(
+        project_id="a1b2c3d4-e5f6-7890-abcd-ef1234567890",
+        task_type=TaskType.INSTANCE_SEGMENTATION,
+        exclusive_labels=True,
+    )
+    segmentation_labels = _create_segmentation_labels(project_id=segmentation_project.id)
+
+    detection_pipeline = None
+    instance_segmentation_pipeline = None
+    if with_model:
+        detection_pipeline = _create_pipeline_with_video_source(
+            project_id=detection_project.id,
+            source_id="f6b1ac22-e36c-4b36-9a23-62b0881e4223",
+            source_name="Video Source - Detection",
+            video_path="data/media/card-video.mp4",
+            sink_id=folders.id,
+            model_id="977eeb18-eaac-449d-bc80-e340fbe052ad",
+            model_architecture="Object_Detection_SSD",
+            labels=detection_labels,
         )
-        db.add(project)
-        db.flush()
-        labels = [
-            LabelDB(project_id=project_id, name="Clubs", color="#2d6311", hotkey="c"),
-            LabelDB(project_id=project_id, name="Diamonds", color="#baa3b3", hotkey="d"),
-            LabelDB(project_id=project_id, name="Spades", color="#000702", hotkey="s"),
-            LabelDB(project_id=project_id, name="Hearts", color="#1f016b", hotkey="h"),
-            LabelDB(project_id=project_id, name="No_object", color="#565a84", hotkey="n"),
-        ]
-        db.add_all(labels)
-        db.flush()

-        # Create default disconnected source and sink
-        disconnected_source_cfg = DisconnectedSourceConfig()
-        disconnected_source = SourceDB(
-            id="00000000-0000-0000-0000-000000000000",
-            name=disconnected_source_cfg.name,
-            source_type=disconnected_source_cfg.source_type,
-            config_data={},
-        )
-        disconnected_sink_cfg = DisconnectedSinkConfig()
-        disconnected_sink = SinkDB(
-            id="00000000-0000-0000-0000-000000000000",
-            name=disconnected_sink_cfg.name,
-            sink_type=disconnected_sink_cfg.sink_type,
-            output_formats=[],
-            config_data={},
-        )
-        folder_sink = SinkDB(
-            id="6ee0c080-c7d9-4438-a7d2-067fd395eecf",
-            name="Folder Sink",
-            sink_type=SinkType.FOLDER,
-            rate_limit=0.2,
-            output_formats=[OutputFormat.IMAGE_ORIGINAL, OutputFormat.IMAGE_WITH_PREDICTIONS, OutputFormat.PREDICTIONS],
-            config_data={"folder_path": "data/output"},
+        instance_segmentation_pipeline = _create_pipeline_with_video_source(
+            project_id=segmentation_project.id,
+            source_id="b2c3d4e5-f6a7-8901-bcde-f12345678901",
+            source_name="Video Source - Segmentation",
+            video_path="data/media/fish-video.mp4",
+            sink_id=folders.id,
+            model_id="c3d4e5f6-a7b8-9012-cdef-123456789012",
+            model_architecture="Custom_Instance_Segmentation_RTMDet_tiny",
+            labels=segmentation_labels,
         )
-        db.add_all([disconnected_source, disconnected_sink, folder_sink])
+
+    with get_db_session() as db:
+        db.add_all([sources, sinks, folders, detection_project, segmentation_project])
+        db.flush()
+        db.add_all(detection_labels + segmentation_labels)
         db.flush()

-        pipeline = PipelineDB(project_id=project.id)
-        pipeline.source = SourceDB(
-            id="f6b1ac22-e36c-4b36-9a23-62b0881e4223",
-            name="Video Source",
-            source_type=SourceType.VIDEO_FILE,
-            config_data={"video_path": "data/media/video.mp4"},
-        )
-        pipeline.sink_id = folder_sink.id
-        pipeline.data_collection_policies = [FixedRateDataCollectionPolicy(rate=0.1).model_dump(mode="json")]
-        if with_model:
-            pipeline.model_revision = ModelRevisionDB(
-                id="977eeb18-eaac-449d-bc80-e340fbe052ad",
-                project_id=project.id,
-                architecture="Object_Detection_SSD",
-                training_status=TrainingStatus.SUCCESSFUL,
-                training_started_at=datetime.now() - timedelta(hours=24),
-                training_finished_at=datetime.now() - timedelta(hours=23),
-                training_configuration={},
-                label_schema_revision={"labels": [{"id": str(label.id), "name": label.name} for label in labels]},
-            )
-            pipeline.is_running = True
-        db.add(pipeline)
+        if with_model and detection_pipeline and instance_segmentation_pipeline:
+            db.add_all([detection_pipeline, instance_segmentation_pipeline])
+            db.flush()
+
         db.commit()
         click.echo("✓ Seeding successful!")
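For reference, the reworked seed command can be exercised end to end with click's built-in test runner. The snippet below is a minimal sketch, not part of this commit: it assumes the click group defined in application/backend/app/cli.py is importable as app.cli.cli (the @cli.command() decorator above implies a group named cli), and it passes --with-model as a value because the option is declared with default=False rather than is_flag=True.

    # Sketch only: exercising the new `seed` command via click's test runner.
    # Assumption: the click group in app/cli.py is exported as `cli`.
    from click.testing import CliRunner

    from app.cli import cli

    runner = CliRunner()
    # `--with-model` takes a boolean value (it is not a flag), so pass it explicitly.
    result = runner.invoke(cli, ["seed", "--with-model", "true"])
    assert result.exit_code == 0, result.output
    assert "Seeding successful" in result.output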

application/backend/app/db_seeder.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from datetime import datetime, timedelta
from uuid import UUID

from app.db.schema import LabelDB, ModelRevisionDB, PipelineDB, ProjectDB, SinkDB, SourceDB
from app.models import (
    DisconnectedSinkConfig,
    DisconnectedSourceConfig,
    FixedRateDataCollectionPolicy,
    OutputFormat,
    SinkType,
    SourceType,
    TaskType,
    TrainingStatus,
)


def _create_shared_sinks_sources_folders() -> tuple[SourceDB, SinkDB, SinkDB]:
    """
    Create shared source, sink, and folder sink entities.

    Returns:
        tuple[SourceDB, SinkDB, SinkDB]: Created source, sink, and folder sink objects.
    """
    disconnected_source_cfg = DisconnectedSourceConfig()
    disconnected_source = SourceDB(
        id="00000000-0000-0000-0000-000000000000",
        name=disconnected_source_cfg.name,
        source_type=disconnected_source_cfg.source_type,
        config_data={},
    )
    disconnected_sink_cfg = DisconnectedSinkConfig()
    disconnected_sink = SinkDB(
        id="00000000-0000-0000-0000-000000000000",
        name=disconnected_sink_cfg.name,
        sink_type=disconnected_sink_cfg.sink_type,
        output_formats=[],
        config_data={},
    )
    folder_sink = SinkDB(
        id="6ee0c080-c7d9-4438-a7d2-067fd395eecf",
        name="Folder Sink",
        sink_type=SinkType.FOLDER,
        rate_limit=0.2,
        output_formats=[OutputFormat.IMAGE_ORIGINAL, OutputFormat.IMAGE_WITH_PREDICTIONS, OutputFormat.PREDICTIONS],
        config_data={"folder_path": "data/output"},
    )
    return disconnected_source, disconnected_sink, folder_sink


def _create_project(
    project_id: str | UUID,
    task_type: TaskType,
    exclusive_labels: bool = True,
) -> ProjectDB:
    """
    Create a project in the database.

    Args:
        project_id (str | UUID): Unique identifier for the project.
        task_type (TaskType): Type of task (e.g., DETECTION, INSTANCE_SEGMENTATION, CLASSIFICATION).
        exclusive_labels (bool): Whether labels are mutually exclusive.

    Returns:
        ProjectDB: Created project object.
    """
    return ProjectDB(
        id=project_id,
        name=f"Demo {task_type} project",
        task_type=task_type,
        exclusive_labels=exclusive_labels,
    )


def _create_detection_labels(project_id: str | UUID) -> list[LabelDB]:
    """
    Create labels for a Detection card project.

    Args:
        project_id (str | UUID): ID of the project to add labels to.

    Returns:
        list[LabelDB]: List of created label objects.
    """
    return [
        LabelDB(project_id=project_id, name="Clubs", color="#2d6311", hotkey="c"),
        LabelDB(project_id=project_id, name="Diamonds", color="#baa3b3", hotkey="d"),
        LabelDB(project_id=project_id, name="Spades", color="#000702", hotkey="s"),
        LabelDB(project_id=project_id, name="Hearts", color="#1f016b", hotkey="h"),
        LabelDB(project_id=project_id, name="No_object", color="#565a84", hotkey="n"),
    ]


def _create_segmentation_labels(project_id: str | UUID) -> list[LabelDB]:
    """
    Create labels for an Instance Segmentation fish project.

    Args:
        project_id (str | UUID): ID of the project to add labels to.

    Returns:
        list[LabelDB]: List of created label objects.
    """
    return [
        LabelDB(project_id=project_id, name="Fish", color="#2d6311", hotkey="f"),
        LabelDB(project_id=project_id, name="Empty", color="#565a84", hotkey="e"),
    ]


def _create_pipeline_with_video_source(  # noqa: PLR0913
    project_id: str | UUID,
    source_id: str | UUID,
    source_name: str,
    video_path: str,
    sink_id: str | UUID,
    model_id: str,
    model_architecture: str,
    labels: list[LabelDB],
) -> PipelineDB:
    """
    Create a pipeline with a video file source for a project.

    Args:
        project_id (str | UUID): ID of the project.
        source_id (str | UUID): Unique identifier for the video source.
        source_name (str): Name for the video source.
        video_path (str): Path to the video file.
        sink_id (str | UUID): ID of the sink to use.
        model_id (str): Unique identifier for the model revision.
        model_architecture (str): Architecture name of the model.
        labels (list[LabelDB]): List of labels for the label schema revision.

    Returns:
        PipelineDB: Created pipeline object.
    """
    pipeline = PipelineDB(
        project_id=project_id,
        sink_id=sink_id,
        data_collection_policies=[FixedRateDataCollectionPolicy(rate=0.1).model_dump(mode="json")],
        is_running=True,
    )

    pipeline.source = SourceDB(
        id=source_id,
        name=source_name,
        source_type=SourceType.VIDEO_FILE,
        config_data={"video_path": video_path},
    )

    pipeline.model_revision = ModelRevisionDB(
        id=model_id,
        project_id=project_id,
        architecture=model_architecture,
        training_status=TrainingStatus.SUCCESSFUL,
        training_started_at=datetime.now() - timedelta(hours=24),
        training_finished_at=datetime.now() - timedelta(hours=23),
        training_configuration={},
        label_schema_revision={"labels": [{"id": str(label.id), "name": label.name} for label in labels]},
    )
    return pipeline
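As a rough illustration of how these helpers compose outside of cli.py, the sketch below seeds one extra Detection project by reusing _create_project and _create_detection_labels. It is hypothetical and not part of the commit; the project ID is an arbitrary placeholder.

    # Hypothetical usage of the db_seeder helpers; the fixed ID below is a placeholder.
    from app.db import get_db_session
    from app.db_seeder import _create_detection_labels, _create_project
    from app.models import TaskType

    extra_project = _create_project(
        project_id="11111111-2222-3333-4444-555555555555",  # placeholder, not from this commit
        task_type=TaskType.DETECTION,
    )
    extra_labels = _create_detection_labels(project_id=extra_project.id)

    with get_db_session() as db:
        db.add(extra_project)
        db.flush()  # flush so the project row exists before its labels are inserted
        db.add_all(extra_labels)
        db.commit()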
