|
4 | 4 | """Command line interface for interacting with the Geti Tune application.""" |
5 | 5 |
|
6 | 6 | import sys |
7 | | -from datetime import datetime, timedelta |
8 | 7 |
|
9 | 8 | import click |
10 | 9 |
|
11 | 10 | from app.db import MigrationManager, get_db_session |
12 | | -from app.db.schema import DatasetItemDB, LabelDB, ModelRevisionDB, PipelineDB, ProjectDB, SinkDB, SourceDB |
13 | | -from app.models import ( |
14 | | - DisconnectedSinkConfig, |
15 | | - DisconnectedSourceConfig, |
16 | | - FixedRateDataCollectionPolicy, |
17 | | - OutputFormat, |
18 | | - SinkType, |
19 | | - SourceType, |
20 | | - TaskType, |
21 | | - TrainingStatus, |
| 11 | +from app.db.schema import DatasetItemDB, ModelRevisionDB, ProjectDB, SinkDB, SourceDB |
| 12 | +from app.db_seeder import ( |
| 13 | + _create_detection_labels, |
| 14 | + _create_pipeline_with_video_source, |
| 15 | + _create_project, |
| 16 | + _create_segmentation_labels, |
| 17 | + _create_shared_sinks_sources_folders, |
22 | 18 | ) |
| 19 | +from app.models import TaskType |
23 | 20 | from app.settings import get_settings |
24 | 21 |
|
25 | 22 | settings = get_settings() |
@@ -84,79 +81,68 @@ def check_db() -> None: |
84 | 81 | @cli.command() |
85 | 82 | @click.option("--with-model", default=False) |
86 | 83 | def seed(with_model: bool) -> None: |
87 | | - """Seed the database with test data.""" |
| 84 | + """ |
| 85 | + Seed the database with test data. |
| 86 | +
|
| 87 | + Args: |
| 88 | + with_model (bool): Whether to include pre-trained models in the seed data. |
| 89 | + """ |
88 | 90 | # If the app is running, it needs to be restarted since it doesn't track direct DB changes |
89 | 91 | # Fixed IDs are used to ensure consistency in tests |
90 | 92 | click.echo("Seeding database with test data...") |
91 | | - project_id = "9d6af8e8-6017-4ebe-9126-33aae739c5fa" |
92 | | - with get_db_session() as db: |
93 | | - project = ProjectDB( |
94 | | - id=project_id, |
95 | | - name="Test Project", |
96 | | - task_type=TaskType.DETECTION, |
97 | | - exclusive_labels=True, |
| 93 | + sources, sinks, folders = _create_shared_sinks_sources_folders() |
| 94 | + |
| 95 | + # Project 1: Object Detection |
| 96 | + detection_project = _create_project( |
| 97 | + project_id="9d6af8e8-6017-4ebe-9126-33aae739c5fa", |
| 98 | + task_type=TaskType.DETECTION, |
| 99 | + exclusive_labels=True, |
| 100 | + ) |
| 101 | + detection_labels = _create_detection_labels(project_id=detection_project.id) |
| 102 | + |
| 103 | + # Project 2: Instance Segmentation |
| 104 | + segmentation_project = _create_project( |
| 105 | + project_id="a1b2c3d4-e5f6-7890-abcd-ef1234567890", |
| 106 | + task_type=TaskType.INSTANCE_SEGMENTATION, |
| 107 | + exclusive_labels=True, |
| 108 | + ) |
| 109 | + segmentation_labels = _create_segmentation_labels(project_id=segmentation_project.id) |
| 110 | + |
| 111 | + detection_pipeline = None |
| 112 | + instance_segmentation_pipeline = None |
| 113 | + if with_model: |
| 114 | + detection_pipeline = _create_pipeline_with_video_source( |
| 115 | + project_id=detection_project.id, |
| 116 | + source_id="f6b1ac22-e36c-4b36-9a23-62b0881e4223", |
| 117 | + source_name="Video Source - Detection", |
| 118 | + video_path="data/media/card-video.mp4", |
| 119 | + sink_id=folders.id, |
| 120 | + model_id="977eeb18-eaac-449d-bc80-e340fbe052ad", |
| 121 | + model_architecture="Object_Detection_SSD", |
| 122 | + labels=detection_labels, |
98 | 123 | ) |
99 | | - db.add(project) |
100 | | - db.flush() |
101 | | - labels = [ |
102 | | - LabelDB(project_id=project_id, name="Clubs", color="#2d6311", hotkey="c"), |
103 | | - LabelDB(project_id=project_id, name="Diamonds", color="#baa3b3", hotkey="d"), |
104 | | - LabelDB(project_id=project_id, name="Spades", color="#000702", hotkey="s"), |
105 | | - LabelDB(project_id=project_id, name="Hearts", color="#1f016b", hotkey="h"), |
106 | | - LabelDB(project_id=project_id, name="No_object", color="#565a84", hotkey="n"), |
107 | | - ] |
108 | | - db.add_all(labels) |
109 | | - db.flush() |
110 | 124 |
|
111 | | - # Create default disconnected source and sink |
112 | | - disconnected_source_cfg = DisconnectedSourceConfig() |
113 | | - disconnected_source = SourceDB( |
114 | | - id="00000000-0000-0000-0000-000000000000", |
115 | | - name=disconnected_source_cfg.name, |
116 | | - source_type=disconnected_source_cfg.source_type, |
117 | | - config_data={}, |
118 | | - ) |
119 | | - disconnected_sink_cfg = DisconnectedSinkConfig() |
120 | | - disconnected_sink = SinkDB( |
121 | | - id="00000000-0000-0000-0000-000000000000", |
122 | | - name=disconnected_sink_cfg.name, |
123 | | - sink_type=disconnected_sink_cfg.sink_type, |
124 | | - output_formats=[], |
125 | | - config_data={}, |
126 | | - ) |
127 | | - folder_sink = SinkDB( |
128 | | - id="6ee0c080-c7d9-4438-a7d2-067fd395eecf", |
129 | | - name="Folder Sink", |
130 | | - sink_type=SinkType.FOLDER, |
131 | | - rate_limit=0.2, |
132 | | - output_formats=[OutputFormat.IMAGE_ORIGINAL, OutputFormat.IMAGE_WITH_PREDICTIONS, OutputFormat.PREDICTIONS], |
133 | | - config_data={"folder_path": "data/output"}, |
| 125 | + instance_segmentation_pipeline = _create_pipeline_with_video_source( |
| 126 | + project_id=segmentation_project.id, |
| 127 | + source_id="b2c3d4e5-f6a7-8901-bcde-f12345678901", |
| 128 | + source_name="Video Source - Segmentation", |
| 129 | + video_path="data/media/fish-video.mp4", |
| 130 | + sink_id=folders.id, |
| 131 | + model_id="c3d4e5f6-a7b8-9012-cdef-123456789012", |
| 132 | + model_architecture="Custom_Instance_Segmentation_RTMDet_tiny", |
| 133 | + labels=segmentation_labels, |
134 | 134 | ) |
135 | | - db.add_all([disconnected_source, disconnected_sink, folder_sink]) |
| 135 | + |
| 136 | + with get_db_session() as db: |
| 137 | + db.add_all([sources, sinks, folders, detection_project, segmentation_project]) |
| 138 | + db.flush() |
| 139 | + db.add_all(detection_labels + segmentation_labels) |
136 | 140 | db.flush() |
137 | 141 |
|
138 | | - pipeline = PipelineDB(project_id=project.id) |
139 | | - pipeline.source = SourceDB( |
140 | | - id="f6b1ac22-e36c-4b36-9a23-62b0881e4223", |
141 | | - name="Video Source", |
142 | | - source_type=SourceType.VIDEO_FILE, |
143 | | - config_data={"video_path": "data/media/video.mp4"}, |
144 | | - ) |
145 | | - pipeline.sink_id = folder_sink.id |
146 | | - pipeline.data_collection_policies = [FixedRateDataCollectionPolicy(rate=0.1).model_dump(mode="json")] |
147 | | - if with_model: |
148 | | - pipeline.model_revision = ModelRevisionDB( |
149 | | - id="977eeb18-eaac-449d-bc80-e340fbe052ad", |
150 | | - project_id=project.id, |
151 | | - architecture="Object_Detection_SSD", |
152 | | - training_status=TrainingStatus.SUCCESSFUL, |
153 | | - training_started_at=datetime.now() - timedelta(hours=24), |
154 | | - training_finished_at=datetime.now() - timedelta(hours=23), |
155 | | - training_configuration={}, |
156 | | - label_schema_revision={"labels": [{"id": str(label.id), "name": label.name} for label in labels]}, |
157 | | - ) |
158 | | - pipeline.is_running = True |
159 | | - db.add(pipeline) |
| 142 | + if with_model and detection_pipeline and instance_segmentation_pipeline: |
| 143 | + db.add_all([detection_pipeline, instance_segmentation_pipeline]) |
| 144 | + db.flush() |
| 145 | + |
160 | 146 | db.commit() |
161 | 147 | click.echo("✓ Seeding successful!") |
162 | 148 |
|
|
0 commit comments