Skip to content

Commit 5c448ea

Browse files
maxxgx and Copilot authored
Bugfix: subset split configuration not applied during training (#1004)
Co-authored-by: Copilot <[email protected]>
1 parent 42c4af4 commit 5c448ea

File tree

28 files changed

+555
-282
lines changed

28 files changed

+555
-282
lines changed

interactive_ai/services/auto_train/app/job_creation_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,10 @@ def create_payload(self) -> dict:
116116
"retain_training_artifacts": self.retain_training_artifacts,
117117
}
118118
if FeatureFlagProvider.is_enabled(FeatureFlag.FEATURE_FLAG_NEW_CONFIGURABLE_PARAMETERS):
119-
payload["hyperparameters_json"] = (
119+
payload["training_configuration_json"] = (
120120
# Use model_dump_json to avoid int casting into floats
121-
self.training_configuration.hyperparameters.model_dump_json(
122-
exclude={"training": {"allowed_values_input_size"}}, exclude_none=True
121+
self.training_configuration.model_dump_json(
122+
exclude={"hyperparameters": {"training": {"allowed_values_input_size"}}}, exclude_none=True
123123
)
124124
if self.training_configuration
125125
else None

interactive_ai/services/director/app/communication/kafka_handler.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import logging
1010
import os
1111
from datetime import datetime
12-
from typing import TYPE_CHECKING
1312

1413
from communication.exceptions import MissingJobPayloadAttribute
1514
from metrics.instruments import (
@@ -19,20 +18,27 @@
1918
)
2019
from service.job_submission.job_creation_helpers import JobType
2120
from service.project_service import ProjectService
21+
from storage.repos.partial_training_configuration_repo import PartialTrainingConfigurationRepo
2222

2323
from geti_kafka_tools import BaseKafkaHandler, KafkaRawMessage, TopicSubscription
2424
from geti_telemetry_tools import unified_tracing
2525
from geti_types import CTX_SESSION_VAR, ID, ProjectIdentifier, Singleton
26+
from iai_core.entities.model import Model
2627
from iai_core.entities.model_storage import ModelStorageIdentifier
2728
from iai_core.entities.model_test_result import TestState
28-
from iai_core.repos import ModelRepo, ModelStorageRepo, ModelTestResultRepo, TaskNodeRepo
29+
from iai_core.entities.subset import Subset
30+
from iai_core.repos import (
31+
DatasetRepo,
32+
DatasetStorageRepo,
33+
ModelRepo,
34+
ModelStorageRepo,
35+
ModelTestResultRepo,
36+
TaskNodeRepo,
37+
)
2938
from iai_core.session.session_propagation import setup_session_kafka
3039
from iai_core.utils.deletion_helpers import DeletionHelpers
3140
from iai_core.utils.type_helpers import str2bool
3241

33-
if TYPE_CHECKING:
34-
from iai_core.entities.model import Model
35-
3642
logger = logging.getLogger(__name__)
3743

3844

@@ -100,6 +106,57 @@ def on_job_finished(self, raw_message: KafkaRawMessage) -> None:
100106
job_status=TrainingDurationCounterJobStatus.SUCCEEDED,
101107
)
102108

109+
# update training subset proportions with values used in the training job
110+
project_identifier = ProjectIdentifier(
111+
workspace_id=workspace_id,
112+
project_id=project_id,
113+
)
114+
self._update_subset_split_configuration(project_identifier=project_identifier, model=base_model)
115+
@staticmethod
def _update_subset_split_configuration(project_identifier: ProjectIdentifier, model: Model) -> None:
    """
    Update the subset split configuration with actual values used after model training job.

    The per-subset item counts of the model's training dataset are converted into
    integer percentages (guaranteed to sum to exactly 100) and written back into
    the project's partial training configuration for the model's manifest.

    :param project_identifier: The identifier of the project
    :param model: The model containing the dataset information to update splits from
    """
    training_storage = DatasetStorageRepo(project_identifier).get_one(extra_filter={"use_for_training": True})
    per_subset = DatasetRepo(training_storage.identifier).count_per_subset(dataset_id=model.train_dataset_id)
    config_repo = PartialTrainingConfigurationRepo(project_identifier)
    configuration = config_repo.get_by_model_manifest_id(
        model_manifest_id=model.model_storage.model_manifest_id
    )

    train_count = per_subset.get(Subset.TRAINING.name, 0)
    val_count = per_subset.get(Subset.VALIDATION.name, 0)
    test_count = per_subset.get(Subset.TESTING.name, 0)
    total_count = train_count + val_count + test_count

    if total_count:
        # Truncate validation/test to whole percents (0-100 range); training
        # absorbs the remainder so the three values always add up to exactly 100.
        val_pct = int(100 * val_count / total_count)
        test_pct = int(100 * test_count / total_count)
        train_pct = 100 - val_pct - test_pct
    else:
        logger.warning(
            f"Cannot update subset split configuration for project {project_identifier}: "
            "total number of samples is zero. Setting all split percentages to zero."
        )
        train_pct = 0
        val_pct = 0
        test_pct = 0

    # Write the observed percentages into the configuration and persist it
    split = configuration.global_parameters.dataset_preparation.subset_split
    split.training = train_pct
    split.validation = val_pct
    split.test = test_pct

    config_repo.save(configuration)
159+
103160
@setup_session_kafka
104161
@unified_tracing
105162
def on_job_failed(self, raw_message: KafkaRawMessage) -> None:

interactive_ai/services/director/app/service/job_submission/job_creation_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,10 @@ def create_payload(self) -> dict:
162162
"retain_training_artifacts": self.retain_training_artifacts,
163163
}
164164
if FeatureFlagProvider.is_enabled(FeatureFlag.FEATURE_FLAG_NEW_CONFIGURABLE_PARAMETERS):
165-
payload["hyperparameters_json"] = (
165+
payload["training_configuration_json"] = (
166166
# Use model_dump_json to avoid int casting into floats
167-
self.training_configuration.hyperparameters.model_dump_json(
168-
exclude={"training": {"allowed_values_input_size"}}, exclude_none=True
167+
self.training_configuration.model_dump_json(
168+
exclude={"hyperparameters": {"training": {"allowed_values_input_size"}}}, exclude_none=True
169169
)
170170
if self.training_configuration
171171
else None

interactive_ai/services/director/tests/fixtures/training_configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def fxt_training_configuration_task_level_rest_view(fxt_training_configuration_t
210210
"value": 10,
211211
},
212212
{
213-
"default_value": True,
213+
"default_value": False,
214214
"description": "Whether to automatically select data for each subset",
215215
"key": "auto_selection",
216216
"name": "Auto selection",
@@ -503,7 +503,7 @@ def fxt_training_configuration_full_rest_view(
503503
"value": 10,
504504
},
505505
{
506-
"default_value": True,
506+
"default_value": False,
507507
"description": "Whether to automatically select data for each subset",
508508
"key": "auto_selection",
509509
"name": "Auto selection",

interactive_ai/services/director/tests/unit/test_job_kafka_handler.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55
from unittest.mock import MagicMock, patch
66

77
import pytest
8+
from geti_configuration_tools.training_configuration import GlobalParameters, TrainingConfiguration
89
from tests.unit.mocked_method_helpers import return_none
910

1011
from communication.kafka_handler import JobKafkaHandler
1112
from service.job_submission.job_creation_helpers import JobType
1213
from service.project_service import ProjectService
14+
from storage.repos.partial_training_configuration_repo import PartialTrainingConfigurationRepo
1315

1416
from geti_kafka_tools import KafkaRawMessage
1517
from geti_types import ID, ProjectIdentifier
1618
from iai_core.entities.model import Model
1719
from iai_core.entities.model_storage import ModelStorage
18-
from iai_core.repos import AnnotationSceneRepo, ModelRepo, ModelTestResultRepo
20+
from iai_core.entities.subset import Subset
21+
from iai_core.repos import AnnotationSceneRepo, DatasetRepo, ModelRepo, ModelTestResultRepo
1922
from iai_core.utils.deletion_helpers import DeletionHelpers
2023

2124
WORKSPACE_ID = "63b183d00000000000000001"
@@ -81,6 +84,10 @@ def test_on_training_finished(
8184
model_storage_id = ID("model_storage_id")
8285
model_id = ID("model_id")
8386
mocked_get_model_by_id.return_value = fxt_model
87+
project_identifier = ProjectIdentifier(
88+
workspace_id=ID(WORKSPACE_ID),
89+
project_id=project_id,
90+
)
8491

8592
MagicMock(spec=ModelStorage)
8693
mock_base_model = MagicMock(spec=Model)
@@ -90,12 +97,15 @@ def test_on_training_finished(
9097
mocked_get_optimized_models.return_value = mock_optimized_models
9198

9299
# Act
93-
with patch.object(ProjectService, "unlock") as mock_unlock_project:
100+
with (
101+
patch.object(ProjectService, "unlock") as mock_unlock_project,
102+
patch.object(JobKafkaHandler, "_update_subset_split_configuration") as mock_config_update,
103+
):
94104
fxt_job_kafka_handler.on_job_finished(
95105
fxt_consumer_record_maker(
96106
{
97107
"job_type": job_type,
98-
"workspace_id": ID("workspace_id"),
108+
"workspace_id": project_identifier.workspace_id,
99109
"job_payload": {
100110
"project_id": project_id,
101111
"task_id": task_id,
@@ -121,6 +131,10 @@ def test_on_training_finished(
121131
model=mock_base_model,
122132
training_job_duration=(end_time - start_time).total_seconds(),
123133
)
134+
mock_config_update.assert_called_once_with(
135+
project_identifier=project_identifier,
136+
model=mock_base_model,
137+
)
124138

125139
@patch.object(DeletionHelpers, "delete_models_by_base_model_id")
126140
def test_on_training_cancelled(
@@ -449,3 +463,64 @@ def test_on_optimize_job_cancelled(
449463
# Assert
450464
mock_unlock_project(job_type=job_type, project_id=project_id)
451465
mock_delete.assert_called_once_with(mock_model.id_)
466+
467+
def test_update_subset_split_configuration(self) -> None:
    """Verify the handler recomputes and persists subset split percentages."""
    # Arrange
    identifier = ProjectIdentifier(
        workspace_id=ID(WORKSPACE_ID),
        project_id=ID("project_id"),
    )

    # Model whose training dataset drives the split computation
    model = MagicMock(spec=Model)
    model.train_dataset_id = ID("dataset_id")
    model.model_storage = MagicMock()
    model.model_storage.model_manifest_id = "YOLOX"

    # Training configuration holding an outdated split that should be overwritten
    config = MagicMock(spec=TrainingConfiguration)
    config.global_parameters = MagicMock(spec=GlobalParameters)
    config.global_parameters.dataset_preparation = MagicMock()
    split = MagicMock()
    split.training = 70
    split.validation = 20
    split.test = 10
    config.global_parameters.dataset_preparation.subset_split = split

    # Imbalanced counts whose truncated percentages do not naturally sum to 100
    counts = {
        Subset.TRAINING.name: 155,
        Subset.VALIDATION.name: 37,
        Subset.TESTING.name: 18,
    }

    # Act
    with (
        patch.object(DatasetRepo, "count_per_subset", return_value=counts),
        patch.object(
            PartialTrainingConfigurationRepo,
            "get_by_model_manifest_id",
            return_value=config,
        ) as mock_get_config,
        patch.object(
            PartialTrainingConfigurationRepo,
            "save",
        ) as mock_save,
    ):
        JobKafkaHandler._update_subset_split_configuration(project_identifier=identifier, model=model)

    # Assert
    mock_get_config.assert_called_once_with(
        model_manifest_id=model.model_storage.model_manifest_id
    )

    # 37/210 truncates to 17%, 18/210 to 8%; training takes the remainder (75%)
    assert split.training == 75
    assert split.validation == 17
    assert split.test == 8

    mock_save.assert_called_once_with(config)

interactive_ai/workflows/common/jobs_common/utils/dataset_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import copy
66
import os
77

8+
from geti_configuration_tools.training_configuration import TrainingConfiguration
89
from geti_kafka_tools import publish_event
910
from geti_telemetry_tools import unified_tracing
1011
from geti_types import CTX_SESSION_VAR, ID, DatasetStorageIdentifier, ProjectIdentifier
@@ -211,6 +212,7 @@ def construct_and_save_train_dataset_for_task(
211212
project_id: ID,
212213
task_node: TaskNode,
213214
dataset_storage: DatasetStorage,
215+
training_configuration: TrainingConfiguration,
214216
max_training_dataset_size: int | None = None,
215217
reshuffle_subsets: bool = False,
216218
) -> Dataset:
@@ -225,6 +227,7 @@ def construct_and_save_train_dataset_for_task(
225227
:param project_id: ID of the project
226228
:param task_node: Task node for which the dataset is fetched
227229
:param dataset_storage: DatasetStorage containing the dataset items
230+
:param training_configuration: Training configuration containing dataset preparation parameters
228231
:param max_training_dataset_size: maximum training dataset size
229232
:param reshuffle_subsets: Whether to reassign/shuffle all the items to subsets including Test set from scratch
230233
:return: A copy of the current dataset, split into subsets.
@@ -253,6 +256,8 @@ def construct_and_save_train_dataset_for_task(
253256
TaskSubsetManager.split(
254257
dataset_items=iter(training_dataset_items),
255258
task_node=task_node,
259+
subset_split_config=training_configuration.global_parameters.dataset_preparation.subset_split,
260+
filtering_config=training_configuration.global_parameters.dataset_preparation.filtering,
256261
subsets_to_reset=subsets_to_reset,
257262
)
258263
task_dataset_entity.save_subsets(dataset=dataset, dataset_storage_identifier=dataset_storage.identifier)

0 commit comments

Comments (0)