Skip to content

Commit fccaa82

Browse files
authored
Bugfix: use correct default model and set intel gpu to true (#1035)
1 parent 5c448ea commit fccaa82

File tree

16 files changed: 24 additions and 52 deletions.

interactive_ai/libs/iai_core_py/iai_core/services/model_service.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -178,16 +178,16 @@ def _get_active_model_state(
178178
def get_or_create_model_storage(
179179
project_identifier: ProjectIdentifier,
180180
task_node: TaskNode,
181-
model_template_id: str,
181+
model_manifest_id: str,
182182
) -> ModelStorage:
183183
"""
184-
Returns the model storage for a particular task_node and model_template. If no
185-
model storage for that task node and model template exists, this method will
184+
Returns the model storage for a particular task_node and model_manifest. If no
185+
model storage for that task node and model manifest exists, this method will
186186
create one.
187187
188188
:param project_identifier: Identifier of the project containing the task node
189189
:param task_node: Task node associated with the model storage to get
190-
:param model_template_id: Identifier of the model template associated with
190+
:param model_manifest_id: Identifier of the model manifest associated with
191191
the model storage to get
192192
:return: ModelStorage
193193
"""
@@ -196,10 +196,10 @@ def get_or_create_model_storage(
196196
model_storages = model_storage_repo.get_by_task_node_id(task_node_id=task_node.id_)
197197

198198
for task_model_storage in model_storages:
199-
if task_model_storage.model_template.model_template_id == model_template_id:
199+
if task_model_storage.model_manifest_id == model_manifest_id:
200200
model_storage = task_model_storage
201201
if model_storage is None:
202-
model_template = ModelTemplateList().get_by_id(model_template_id)
202+
model_template = ModelTemplateList().get_by_id(model_manifest_id)
203203
if (
204204
isinstance(model_template, NullModelTemplate)
205205
or model_template.task_type != task_node.task_properties.task_type
@@ -210,7 +210,7 @@ def get_or_create_model_storage(
210210
if model_template.task_type == task_node.task_properties.task_type
211211
]
212212
raise ValueError(
213-
f"Algorithm with name '{model_template_id}' was not found for task "
213+
f"Algorithm with name '{model_manifest_id}' was not found for task "
214214
f"{task_node.title} of type {task_node.task_properties.task_type}. "
215215
f"Algorithms that are available to this task are: {available_algos}."
216216
)

interactive_ai/libs/iai_core_py/tests/services/test_model_service.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -109,29 +109,29 @@ def test_get_or_create_model_storage(
109109
model_storage = ModelService.get_or_create_model_storage(
110110
project_identifier=fxt_empty_project_persisted.identifier,
111111
task_node=fxt_detection_task,
112-
model_template_id=fxt_model_template_detection.model_template_id,
112+
model_manifest_id=fxt_model_template_detection.model_template_id,
113113
)
114114
request.addfinalizer(lambda: model_storage_repo.delete_by_id(model_storage.id_))
115115
# Test case 2: Model storage retrieval
116116
model_storage_exists = ModelService.get_or_create_model_storage(
117117
project_identifier=fxt_empty_project_persisted.identifier,
118118
task_node=fxt_detection_task,
119-
model_template_id=fxt_model_template_detection.model_template_id,
119+
model_manifest_id=fxt_model_template_detection.model_template_id,
120120
)
121121
# Test case 3: Raise error for invalid model template id
122122
invalid_template_id = "non existing model template id"
123123
with pytest.raises(ValueError) as exc_non_existing:
124124
ModelService.get_or_create_model_storage(
125125
project_identifier=fxt_empty_project_persisted.identifier,
126126
task_node=fxt_detection_task,
127-
model_template_id=invalid_template_id,
127+
model_manifest_id=invalid_template_id,
128128
)
129129
# Test case 4: Raise error for model template with non-matching task type
130130
with pytest.raises(ValueError) as exc_non_matching:
131131
ModelService.get_or_create_model_storage(
132132
project_identifier=fxt_empty_project_persisted.identifier,
133133
task_node=fxt_detection_task,
134-
model_template_id=fxt_model_template_classification.model_template_id,
134+
model_manifest_id=fxt_model_template_classification.model_template_id,
135135
)
136136

137137
# Assert

interactive_ai/services/director/app/communication/controllers/training_controller.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,7 @@ def _submit_train_job(project: Project, task_training_config: TrainingConfig, au
180180
model_storage = ModelService.get_or_create_model_storage(
181181
project_identifier=project.identifier,
182182
task_node=task_node,
183-
model_template_id=task_training_config.model_template_id,
183+
model_manifest_id=task_training_config.model_template_id,
184184
)
185185

186186
logger.info("Submitting train job for task with ID `%s`", task_node.id_)

interactive_ai/services/director/app/configuration/configuration_manager.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -275,7 +275,7 @@ def get_configuration_for_algorithm(
275275
model_storage = ModelService.get_or_create_model_storage(
276276
project_identifier=project.identifier,
277277
task_node=trainable_task,
278-
model_template_id=algorithm_name,
278+
model_manifest_id=algorithm_name,
279279
)
280280
except ValueError as error:
281281
raise AlgorithmNotFoundException(message=str(error))

interactive_ai/services/director/app/usecases/auto_train/kafka_handler.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,15 @@
44

55
import logging
66

7+
from geti_supported_models.default_models import DefaultModels
8+
79
from coordination.dataset_manager.dynamic_required_num_annotations import DynamicRequiredAnnotations
810
from environment import get_gpu_provider
911

1012
from .auto_train_use_case import AutoTrainUseCase
1113
from geti_kafka_tools import BaseKafkaHandler, KafkaRawMessage, TopicSubscription
1214
from geti_telemetry_tools import unified_tracing
1315
from geti_types import ID, ProjectIdentifier, Singleton
14-
from iai_core.algorithms import ModelTemplateList
1516
from iai_core.repos import ProjectRepo
1617
from iai_core.services import ModelService
1718
from iai_core.session.session_propagation import setup_session_kafka
@@ -61,23 +62,19 @@ def on_project_created(self, raw_message: KafkaRawMessage) -> None:
6162
for task_node in project.get_trainable_task_nodes():
6263
task_type = task_node.task_properties.task_type
6364
try:
64-
default_model_template_id = next(
65-
model_template.model_template_id
66-
for model_template in ModelTemplateList().get_all()
67-
if model_template.task_type == task_type and model_template.is_default_for_task
68-
)
69-
except StopIteration:
70-
logger.error("Could not resolve default model template for task type %s", task_type)
65+
default_model_manifest_id = DefaultModels.get_default_model(task_type.name)
66+
except ValueError:
67+
logger.error("Could not resolve default model manifest for task type %s", task_type)
7168
return
7269
logger.info(
7370
"[WORKAROUND] Switching active model storage to '%s' for task '%s'",
74-
default_model_template_id,
71+
default_model_manifest_id,
7572
task_node,
7673
)
7774
model_storage = ModelService.get_or_create_model_storage(
7875
project_identifier=project_identifier,
7976
task_node=task_node,
80-
model_template_id=default_model_template_id,
77+
model_manifest_id=default_model_manifest_id,
8178
)
8279
ModelService.activate_model_storage(model_storage)
8380

interactive_ai/services/director/tests/unit/communication/rest_controllers/test_training_controller.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -306,7 +306,7 @@ def test_submit_train_jobs_new_model_storage(
306306
# Assert
307307
patched_get_model_storage.assert_called_once_with(
308308
task_node=fxt_detection_task,
309-
model_template_id=fxt_model_template_detection.model_template_id,
309+
model_manifest_id=fxt_model_template_detection.model_template_id,
310310
project_identifier=fxt_project.identifier,
311311
)
312312
mock_training_job_submit.assert_called_once_with(

interactive_ai/services/director/tests/unit/configuration/test_configuration_manager.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ def test_get_configuration_for_algorithm_uninitialized_model_storage(
107107
mock_save_hyper_parameters.assert_called_once_with(hyper_parameters)
108108
mock_get_model_storage.assert_called_once_with(
109109
task_node=detection_task,
110-
model_template_id=algorithm_name,
110+
model_manifest_id=algorithm_name,
111111
project_identifier=fxt_project_with_detection_task.identifier,
112112
)
113113

@@ -125,7 +125,7 @@ def test_get_configuration_for_algorithm_initialized_model_storage(
125125
):
126126
# Arrange
127127
detection_task = fxt_project_with_detection_task.tasks[-1]
128-
algorithm_name = "test_model_template_id"
128+
algorithm_name = fxt_model_template_detection.model_manifest_id
129129

130130
# Set a non-default value for one of the parameters to make sure we're not
131131
# just returning the defaults, upon comparison later on
@@ -138,7 +138,6 @@ def test_get_configuration_for_algorithm_initialized_model_storage(
138138
id_to_str=True,
139139
)
140140
)
141-
fxt_model_template_detection.model_template_id = algorithm_name
142141
fxt_model_storage_detection._model_template = fxt_model_template_detection
143142

144143
# Act

interactive_ai/supported_models/geti_supported_models/manifests/detection/dfine_x.yaml

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,6 @@ stats:
1010
training_time: 1
1111
inference_speed: 1
1212

13-
supported_gpus:
14-
intel: false
15-
1613
hyperparameters:
1714
training:
1815
learning_rate: 0.00025

interactive_ai/supported_models/geti_supported_models/manifests/instance_segmentation/maskrcnn_efficientnetb2b.yaml

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,6 @@ stats:
1010
training_time: 3
1111
inference_speed: 3
1212

13-
supported_gpus:
14-
intel: false
15-
1613
hyperparameters:
1714
training:
1815
learning_rate: 0.007

interactive_ai/supported_models/geti_supported_models/manifests/instance_segmentation/maskrcnn_r50_v1.yaml

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,6 @@ stats:
1111
inference_speed: 2
1212
support_status: deprecated
1313

14-
supported_gpus:
15-
intel: false
16-
1714
hyperparameters:
1815
training:
1916
learning_rate: 0.007

0 commit comments

Comments
 (0)