Skip to content

Commit f9fa289

Browse files
authored
Bugfix: Model storage named after template instead of manifest (#1058)
1 parent 1835b00 commit f9fa289

File tree

21 files changed

+356
-131
lines changed

21 files changed

+356
-131
lines changed

interactive_ai/libs/iai_core_py/iai_core/utils/project_builder.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,24 +64,23 @@ class ProjectBuilder:
6464
"""
6565

6666
@staticmethod
67-
def get_default_model_template_by_task_type(task_type: TaskType) -> ModelTemplate:
67+
def get_default_model_template_by_task_type(
68+
task_type: TaskType, default_models_per_task: dict[str, str]
69+
) -> ModelTemplate:
6870
"""
6971
Get the default ModelTemplate which belongs to the task type.
7072
7173
:param task_type: the task type which requires a model template
74+
:param default_models_per_task: a dictionary mapping task type names to default model template IDs
7275
:return: the ModelTemplate associated with the task type
7376
"""
74-
if task_type in [TaskType.DATASET, TaskType.CROP]:
75-
default_model_template = ModelTemplateList().get_by_id(task_type.name.lower())
76-
else:
77-
default_model_template = next(
78-
(
79-
model_template
80-
for model_template in ModelTemplateList().get_all()
81-
if model_template.task_type == task_type and model_template.is_default_for_task
82-
),
83-
NullModelTemplate(),
84-
)
77+
task_type_name = task_type.name.lower()
78+
model_template_id = (
79+
task_type_name
80+
if task_type in [TaskType.DATASET, TaskType.CROP]
81+
else default_models_per_task[task_type_name]
82+
)
83+
default_model_template = ModelTemplateList().get_by_id(model_template_id)
8584

8685
if isinstance(default_model_template, NullModelTemplate):
8786
raise ModelTemplateError("A NullModelTemplate was created.")
@@ -565,6 +564,7 @@ def build_full_project(
565564
creator_id: str,
566565
parser_class: type[ProjectParser],
567566
parser_kwargs: dict,
567+
default_models_per_task: dict[str, str],
568568
) -> tuple[Project, LabelSchema, dict[str, LabelSchemaView]]:
569569
"""
570570
This does NOT save anything to the database.
@@ -583,6 +583,7 @@ def build_full_project(
583583
:param creator_id: the ID of who created the project
584584
:param parser_class: a parser which will be used to get relevant information from the REST response
585585
:param parser_kwargs: arguments to pass to the parser for initialization
586+
:param default_models_per_task: a dictionary mapping task type names to default model template IDs
586587
:return: the project, the label schema, and a mapping of task to label schema view
587588
"""
588589
is_keypoint_detection_enabled = FeatureFlagProvider.is_enabled(FEATURE_FLAG_KEYPOINT_DETECTION)
@@ -608,7 +609,9 @@ def build_full_project(
608609
custom_labels_names = parser.get_custom_labels_names_by_task(
609610
task_name=task_name,
610611
)
611-
model_template = cls.get_default_model_template_by_task_type(task_type=task_type)
612+
model_template = cls.get_default_model_template_by_task_type(
613+
task_type=task_type, default_models_per_task=default_models_per_task
614+
)
612615
task_node = cls._build_task_node(
613616
project_id=project_id,
614617
task_name=task_name,
@@ -1060,7 +1063,11 @@ class PersistedProjectBuilder(ProjectBuilder):
10601063

10611064
@classmethod
10621065
def build_full_project(
1063-
cls, creator_id: str, parser_class: type[ProjectParser], parser_kwargs: dict
1066+
cls,
1067+
creator_id: str,
1068+
parser_class: type[ProjectParser],
1069+
parser_kwargs: dict,
1070+
default_models_per_task: dict[str, str],
10641071
) -> tuple[Project, LabelSchema, dict[str, LabelSchemaView]]:
10651072
"""
10661073
Overrides the method from the super ProjectBuilder
@@ -1070,13 +1077,19 @@ def build_full_project(
10701077
:param creator_id: the ID of who created the project
10711078
:param parser_class: a parser which will be used to get relevant information from the REST response
10721079
:param parser_kwargs: arguments to pass to the parser for initialization
1080+
:param default_models_per_task: a dictionary mapping task type names to default model template IDs
10731081
:return: the project, the label schema, and a mapping of task to label schema view
10741082
"""
10751083
(
10761084
project,
10771085
label_schema,
10781086
task_to_label_schema_view,
1079-
) = super().build_full_project(creator_id=creator_id, parser_class=parser_class, parser_kwargs=parser_kwargs)
1087+
) = super().build_full_project(
1088+
creator_id=creator_id,
1089+
parser_class=parser_class,
1090+
parser_kwargs=parser_kwargs,
1091+
default_models_per_task=default_models_per_task,
1092+
)
10801093
return project, label_schema, task_to_label_schema_view
10811094

10821095
@classmethod

interactive_ai/libs/iai_core_py/tests/utils/test_project_builder.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def test_detection_label_schema_from_pipeline(
9999
creator_id="Geti",
100100
parser_class=CustomTestProjectParser,
101101
parser_kwargs=fxt_detection_classification_project_data,
102+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
102103
)
103104
detection_label_group = label_schema.get_label_group_by_name("Default Detection")
104105
classification_label_group = label_schema.get_label_group_by_name("Default Classification")
@@ -127,6 +128,7 @@ def test_anomaly_label_schema_from_pipeline(self, fxt_anomaly_project_data):
127128
creator_id="Geti",
128129
parser_class=CustomTestProjectParser,
129130
parser_kwargs=fxt_anomaly_project_data,
131+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
130132
)
131133

132134
anomaly_label_group = label_schema.get_label_group_by_name("default - anomaly")
@@ -176,6 +178,7 @@ def test_label_schema_with_hierarchy_task_chain_from_pipeline(
176178
creator_id="Geti",
177179
parser_class=CustomTestProjectParser,
178180
parser_kwargs=fxt_label_hierarchy_task_chain_project_data,
181+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
179182
)
180183

181184
expected_group_names = {
@@ -213,6 +216,7 @@ def test_label_schema_with_hierarchy_from_pipeline(
213216
creator_id="Geti",
214217
parser_class=CustomTestProjectParser,
215218
parser_kwargs=fxt_hierarchy_classification_project_data_2,
219+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
216220
)
217221

218222
expected_group_names = {
@@ -268,6 +272,7 @@ def test_label_schema_for_binary_pipeline(
268272
creator_id="Geti",
269273
parser_class=CustomTestProjectParser,
270274
parser_kwargs=fxt_binary_classification_project_data,
275+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
271276
)
272277

273278
label_groups = label_schema.get_groups(include_empty=True)
@@ -291,6 +296,7 @@ def test_label_schema_for_multiclass_pipeline(
291296
creator_id="Geti",
292297
parser_class=CustomTestProjectParser,
293298
parser_kwargs=fxt_multiclass_classification_project_data,
299+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
294300
)
295301
label_groups = label_schema.get_groups(include_empty=True)
296302
labels = label_schema.get_labels(include_empty=True)
@@ -317,6 +323,7 @@ def test_label_schema_for_multilabel_pipeline(
317323
creator_id="Geti",
318324
parser_class=CustomTestProjectParser,
319325
parser_kwargs=fxt_multilabel_classification_project_data,
326+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
320327
)
321328

322329
label_groups = label_schema.get_groups(include_empty=True)
@@ -339,7 +346,10 @@ def test_label_schema_for_segmentation_pipeline(
339346
],
340347
):
341348
project, label_schema, task_schema = ProjectBuilder.build_full_project(
342-
creator_id="Geti", parser_class=CustomTestProjectParser, parser_kwargs=fxt_segmentation_project_data
349+
creator_id="Geti",
350+
parser_class=CustomTestProjectParser,
351+
parser_kwargs=fxt_segmentation_project_data,
352+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
343353
)
344354

345355
label_groups = label_schema.get_groups(include_empty=True)
@@ -391,6 +401,7 @@ def test_detection_to_segmentation_labels_from_pipeline(
391401
creator_id="Geti",
392402
parser_class=CustomTestProjectParser,
393403
parser_kwargs=fxt_detection_to_segmentation_project_data,
404+
default_models_per_task={}, # not necessary as get_default_model_template_by_task_type is mocked
394405
)
395406
expected_group_names = {
396407
"Default Detection",

interactive_ai/services/director/tests/fixtures/database.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import jsonschema
1919
import numpy as np
2020
import pytest
21+
from geti_supported_models.default_models import DefaultModels
2122

2223
from coordination.configuration_manager.task_node_config import TaskNodeConfig
2324

@@ -721,6 +722,7 @@ def create_empty_project(
721722
creator_id="",
722723
parser_class=RestProjectParser,
723724
parser_kwargs=parser_kwargs,
725+
default_models_per_task=DefaultModels.get_default_models_per_task(),
724726
)
725727

726728
return self._project

interactive_ai/services/resource/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ ENV UV_PYTHON_DOWNLOADS=0
2121
COPY --link --from=libs . libs
2222
COPY --link --from=iai_core . interactive_ai/libs/iai_core_py
2323
COPY --link --from=media_utils . interactive_ai/libs/media_utils
24+
COPY --link --from=supported_models . interactive_ai/supported_models
2425

2526
WORKDIR /interactive_ai/services/resource
2627

interactive_ai/services/resource/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ DOCKER_BUILD_CONTEXT := --build-context schemas=../api \
55
--build-context libs=../../../libs \
66
--build-context iai_core=../../../interactive_ai/libs/iai_core_py \
77
--build-context media_utils=../../../interactive_ai/libs/media_utils \
8+
--build-context supported_models=../../../interactive_ai/supported_models \

interactive_ai/services/resource/app/communication/rest_views/model_rest_views.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from enum import Enum, auto
77
from typing import TYPE_CHECKING, Any, cast
88

9+
from geti_supported_models import SupportedModels
10+
911
from communication.rest_views.label_rest_views import LabelRESTViews
1012
from communication.rest_views.performance_rest_views import PerformanceRESTViews
1113

@@ -125,12 +127,13 @@ def model_storage_to_rest(
125127
:return: dictionary with REST representation of the model storage
126128
"""
127129
models_rest_list: list[dict] = []
130+
model_architecture_name = ModelRESTViews._get_model_architecture_name(model_storage=model_storage)
128131
for model_info in per_model_info:
129132
model = model_info.model
130133
models_rest_list.append(
131134
{
132135
ID_: str(model.id_),
133-
NAME: model.model_storage.name,
136+
NAME: model_architecture_name,
134137
CREATION_DATE: model.creation_date.isoformat(),
135138
SCORE_UP_TO_DATE: True, # this field is deprecated, value is always 'True' and should not be used
136139
ACTIVE_MODEL: model == active_model,
@@ -151,7 +154,7 @@ def model_storage_to_rest(
151154

152155
rest_view = {
153156
ID_: str(model_storage.id_),
154-
NAME: model_storage.name,
157+
NAME: model_architecture_name,
155158
MODEL_TEMPLATE_ID: model_storage.model_template.model_template_id,
156159
TASK_ID: str(task_node_id),
157160
MODELS: models_rest_list,
@@ -186,7 +189,7 @@ def model_to_rest(
186189
"""
187190
return {
188191
ID_: str(model.id_),
189-
NAME: model.model_storage.name,
192+
NAME: ModelRESTViews._get_model_architecture_name(model_storage=model.model_storage),
190193
CREATION_DATE: model.creation_date.isoformat(),
191194
SCORE_UP_TO_DATE: True, # Note: this field is deprecated, value is always 'True' and should not be used
192195
ACTIVE_MODEL: model == active_model,
@@ -233,10 +236,11 @@ def model_info_to_rest(
233236
"""
234237
label_schema = model.configuration.get_label_schema()
235238
labels = label_schema.get_labels(include_empty=True)
239+
model_architecture_name = ModelRESTViews._get_model_architecture_name(model_storage=model.model_storage)
236240
result = {
237241
ID_: model.id_,
238-
NAME: model.model_storage.name,
239-
ARCHITECTURE: model.model_storage.model_template.name,
242+
NAME: model_architecture_name,
243+
ARCHITECTURE: model_architecture_name,
240244
VERSION: model.version,
241245
CREATION_DATE: model.creation_date.isoformat(),
242246
SIZE: model.size,
@@ -313,7 +317,7 @@ def optimized_model_to_rest(optimized_model: Model, performance: Performance) ->
313317
filter(
314318
None,
315319
[
316-
optimized_model.model_storage.name,
320+
ModelRESTViews._get_model_architecture_name(model_storage=optimized_model.model_storage),
317321
model_format,
318322
precision_text,
319323
with_xai,
@@ -364,3 +368,14 @@ def optimized_model_to_rest(optimized_model: Model, performance: Performance) ->
364368
result[OPTIMIZATION_METHODS] = [method.name for method in optimized_model.optimization_methods]
365369

366370
return result
371+
372+
@staticmethod
373+
def _get_model_architecture_name(model_storage: ModelStorage) -> str:
374+
"""
375+
Get the model architecture name from the model storage, based on the model manifests.
376+
377+
:param model_storage: ModelStorage to get the architecture name from
378+
:return: The name of the model architecture
379+
"""
380+
model_manifest = SupportedModels.get_model_manifest_by_id(model_manifest_id=model_storage.model_manifest_id)
381+
return model_manifest.name

interactive_ai/services/resource/app/managers/project_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import cast
77

88
from geti_spicedb_tools import Permissions, SpiceDB
9+
from geti_supported_models.default_models import DefaultModels
910

1011
from communication.exceptions import DatasetStorageNotInProjectException, LabelNotFoundException, ProjectLockedException
1112
from managers.annotation_manager import AnnotationManager
@@ -86,6 +87,7 @@ def create_project(
8687
creator_id=creator_id,
8788
parser_class=project_parser,
8889
parser_kwargs=parser_kwargs,
90+
default_models_per_task=DefaultModels.get_default_models_per_task(),
8991
)
9092

9193
# TODO: CVS-89772 call spicedb before storing a project in database

interactive_ai/services/resource/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies = [
2020
"iai_core",
2121
"media_utils",
2222
"grpc_interfaces",
23+
"geti-supported-models",
2324
]
2425

2526
[tool.uv.sources]
@@ -32,6 +33,7 @@ geti_spicedb_tools = { path = "../../../libs/spicedb_tools", editable = true }
3233
iai_core = { path = "../../libs/iai_core_py", editable = true }
3334
media_utils = { path = "../../libs/media_utils", editable = true }
3435
grpc_interfaces = { path = "../../../libs/grpc_interfaces", editable = true }
36+
geti-supported-models = { path = "../../supported_models", editable = true }
3537

3638

3739
[dependency-groups]

interactive_ai/services/resource/tests/fixtures/database.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import cv2
1414
import numpy as np
1515
import pytest
16+
from geti_supported_models.default_models import DefaultModels
1617

1718
from communication.rest_parsers import RestProjectParser
1819
from service.label_schema_service import LabelSchemaService
@@ -446,6 +447,7 @@ def create_empty_project(
446447
creator_id="",
447448
parser_class=RestProjectParser,
448449
parser_kwargs=parser_kwargs,
450+
default_models_per_task=DefaultModels.get_default_models_per_task(),
449451
)
450452

451453
return self._project

interactive_ai/services/resource/tests/fixtures/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -431,13 +431,13 @@ def fxt_model_empty_dataset(
431431
@pytest.fixture
432432
def fxt_model_group_rest(fxt_model, fxt_mongo_id):
433433
yield {
434-
"id": fxt_mongo_id(0),
435-
"name": "Sample Detection Template",
436-
"task_id": fxt_mongo_id(1),
437-
"model_template_id": "test_template_detection",
434+
"id": str(fxt_mongo_id(0)),
435+
"name": fxt_model.model_storage.name,
436+
"task_id": str(fxt_mongo_id(1)),
437+
"model_template_id": fxt_model.model_storage.model_template_id,
438438
"models": [
439439
{
440-
"id": fxt_model.id_,
440+
"id": str(fxt_model.id_),
441441
"name": fxt_model.model_storage.name,
442442
"creation_date": "2020-01-01T00:00:00+00:00",
443443
"performance": {"score": get_performance_score(fxt_model.performance)},

0 commit comments

Comments
 (0)