diff --git a/application/backend/app/alembic/versions/da385d690aae_schema.py b/application/backend/app/alembic/versions/da385d690aae_schema.py index e7b93e9e47..727533d194 100644 --- a/application/backend/app/alembic/versions/da385d690aae_schema.py +++ b/application/backend/app/alembic/versions/da385d690aae_schema.py @@ -138,6 +138,19 @@ def upgrade() -> None: sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="RESTRICT"), sa.PrimaryKeyConstraint("project_id"), ) + op.create_table( + "training_configurations", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("project_id", sa.Text(), nullable=False), + sa.Column("model_architecture_id", sa.String(length=255), nullable=True), + sa.Column("configuration_data", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("project_id", "model_architecture_id", name="uq_project_model_config"), + ) + # ### end Alembic commands ### @@ -151,5 +164,6 @@ def downgrade() -> None: op.drop_table("dataset_revisions") op.drop_table("sources") op.drop_table("sinks") + op.drop_table("training_configurations") op.drop_table("projects") # ### end Alembic commands ### diff --git a/application/backend/app/api/dependencies.py b/application/backend/app/api/dependencies.py index c1e4b2a225..cdecdd4e91 100644 --- a/application/backend/app/api/dependencies.py +++ b/application/backend/app/api/dependencies.py @@ -24,6 +24,7 @@ from app.services.base_weights_service import BaseWeightsService from app.services.data_collect import DataCollector from app.services.label_service import LabelService +from app.services.training_configuration_service import TrainingConfigurationService from app.webrtc.manager import 
def get_training_configuration_service(db: Annotated[Session, Depends(get_db)]) -> TrainingConfigurationService:
    """Build the TrainingConfigurationService used by request handlers.

    Args:
        db (Session): Database session injected through the ``get_db`` dependency.

    Returns:
        TrainingConfigurationService: A service instance bound to ``db``.
    """
    service = TrainingConfigurationService(db_session=db)
    return service
- - -@router.get( - "/{project_id}/training_configuration", - response_model=TrainingConfiguration, - responses={ - status.HTTP_200_OK: {"description": "Training configuration found"}, - status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID or query parameters"}, - status.HTTP_404_NOT_FOUND: {"description": "Project not found"}, - }, -) -def get_training_configuration( - project_id: Annotated[UUID, Depends(get_project_id)], - project_service: Annotated[ProjectService, Depends(get_project_service)], - model_architecture_id: Annotated[str | None, Query()] = None, - model_revision_id: Annotated[UUID | None, Query()] = None, -) -> TrainingConfiguration: - """ - Get the training configuration for a project. - - - If model_architecture_id is provided, returns configuration for that specific model architecture. - - If model_revision_id is provided, returns configuration for a specific trained model. - - If neither is provided, returns only general task-related configuration. - Note: model_architecture_id and model_revision_id cannot be used together. - - Args: - project_id (UUID): The unique identifier of the project. - project_service (ProjectService): The project service - model_architecture_id (Optional[str]): The model architecture ID for specific configuration retrieval. - model_revision_id (Optional[UUID]): The model revision ID for specific configuration retrieval. - - Returns: - TrainingConfiguration: The training configuration details. 
- """ - try: - # TODO: Implement actual training configuration retrieval logic - _ = project_id, project_service, model_architecture_id, model_revision_id - return TrainingConfiguration.from_hyperparameters(hyperparameters=Hyperparameters()) - except ResourceNotFoundError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) - except ValueError as e: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) - - -@router.patch( - "/{project_id}/training_configuration", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Training configuration updated successfully"}, - status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID, query parameters, or request body"}, - status.HTTP_404_NOT_FOUND: {"description": "Project not found"}, - }, -) -def update_training_configuration( - project_id: Annotated[UUID, Depends(get_project_id)], - training_config_update: Annotated[TrainingConfiguration, Body(description="Training configuration updates")], - project_service: Annotated[ProjectService, Depends(get_project_service)], - model_architecture_id: Annotated[str | None, Query()] = None, -) -> None: - """ - Update the training configuration for a project. - - - If model_architecture_id is provided, updates configuration for that specific model architecture. - - If not provided, updates the general task-related configuration. - Note: model_architecture_id cannot be used with model_revision_id for updates. - - Args: - project_id (UUID): The unique identifier of the project. - training_config_update (TrainingConfiguration): The training configuration updates. - project_service (ProjectService): The project service - model_architecture_id (Optional[str]): The model architecture ID for specific configuration update. 
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects/{project_id}/training_configuration", tags=["Training Configuration"])


@router.get(
    "",
    response_model=TrainingConfiguration,
    responses={
        status.HTTP_200_OK: {"description": "Training configuration found"},
        status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID or query parameters"},
        status.HTTP_404_NOT_FOUND: {"description": "Project or model revision not found"},
    },
)
def get_training_configuration(
    training_configuration_service: Annotated[
        TrainingConfigurationService, Depends(get_training_configuration_service)
    ],
    project_id: UUID,
    model_architecture_id: str | None = None,
    model_revision_id: UUID | None = None,
) -> TrainingConfiguration:
    """
    Get the training configuration for a project.

    - If model_architecture_id is provided, returns configuration for that specific model architecture.
    - If model_revision_id is provided, returns configuration for a specific trained model.
    - If neither is provided, returns only general task-related configuration.
    Note: model_architecture_id and model_revision_id cannot be used together.

    Args:
        training_configuration_service (TrainingConfigurationService): The training configuration service.
        project_id (UUID): The unique identifier of the project.
        model_architecture_id (Optional[str]): The model architecture ID for specific configuration retrieval.
        model_revision_id (Optional[UUID]): The model revision ID for specific configuration retrieval.

    Returns:
        TrainingConfiguration: The training configuration details.

    Raises:
        HTTPException: 404 when the service reports a missing resource, 400 when it raises ValueError.
    """
    try:
        return training_configuration_service.get_training_configuration(
            project_id=project_id,
            model_architecture_id=model_architecture_id,
            model_revision_id=model_revision_id,
        )
    except ResourceNotFoundError as e:
        # Chain the original error so the root cause stays visible in logs and tracebacks.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e


@router.patch(
    "",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={
        status.HTTP_204_NO_CONTENT: {"description": "Training configuration updated successfully"},
        status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID, query parameters, or request body"},
        status.HTTP_404_NOT_FOUND: {"description": "Project not found"},
    },
)
def update_training_configuration(
    training_configuration_service: Annotated[
        TrainingConfigurationService, Depends(get_training_configuration_service)
    ],
    project_id: UUID,
    training_config_update: dict,
    model_architecture_id: str | None = None,
) -> None:
    """
    Update the training configuration for a project.

    - If model_architecture_id is provided, updates configuration for that specific model architecture.
    - If not provided, updates the general task-related configuration.
    Note: this endpoint does not accept model_revision_id.

    Request body should contain elements of the configuration hierarchy to update:
    ```json
    {
        "dataset_preparation": {...},
        "training": {...},
        "evaluation": {...}
    }
    ```

    Args:
        training_configuration_service (TrainingConfigurationService): The training configuration service.
        project_id (UUID): The unique identifier of the project.
        training_config_update (dict): The configuration updates to apply.
        model_architecture_id (Optional[str]): The model architecture ID for specific configuration update.

    Raises:
        HTTPException: 404 when the service reports a missing resource, 400 when it raises ValueError.
    """
    try:
        # The service returns the updated configuration; it is discarded because the route answers 204.
        training_configuration_service.update_training_configuration(
            project_id=project_id,
            training_config_update=training_config_update,
            model_architecture_id=model_architecture_id,
        )
    except ResourceNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
class TrainingConfigurationRepository(BaseRepository[TrainingConfigurationDB]):
    """Repository for ``TrainingConfigurationDB`` rows, keyed by project and model architecture."""

    def __init__(self, db: Session) -> None:
        super().__init__(db, TrainingConfigurationDB)

    def get_by_project_and_model_architecture(
        self,
        project_id: str,
        model_architecture_id: str | None = None,
    ) -> TrainingConfigurationDB | None:
        """
        Look up the configuration stored for a project and, optionally, a model architecture.

        Args:
            project_id (str): The ID of the project.
            model_architecture_id (str | None): The ID of the model architecture.

        Returns:
            TrainingConfigurationDB | None: The matching row, or None when there is none.
        """
        query = (
            select(TrainingConfigurationDB)
            .where(TrainingConfigurationDB.project_id == project_id)
            .where(TrainingConfigurationDB.model_architecture_id == model_architecture_id)
        )
        result = self.db.execute(query)
        return result.scalar_one_or_none()

    def create_or_update(
        self,
        project_id: str,
        model_architecture_id: str | None,
        configuration_data: dict,
    ) -> TrainingConfigurationDB:
        """
        Upsert the configuration for a project / model architecture pair.

        When a row already exists for the pair, its configuration data is overwritten;
        otherwise a new row is created. Either way the row is passed to ``save``.

        Args:
            project_id (str): The ID of the project.
            model_architecture_id (str | None): The ID of the model architecture.
            configuration_data (dict): The configuration data to store.

        Returns:
            TrainingConfigurationDB: The created or updated training configuration.
        """
        record = self.get_by_project_and_model_architecture(
            project_id=project_id,
            model_architecture_id=model_architecture_id,
        )

        if not record:
            # Nothing stored yet for this pair: build a fresh row.
            record = TrainingConfigurationDB(
                project_id=project_id,
                model_architecture_id=model_architecture_id,
                configuration_data=configuration_data,
            )
        else:
            record.configuration_data = configuration_data

        self.save(record)
        return record
class TrainingConfigurationService:
    """Resolve and persist training configurations for a project.

    Configurations are stored per (project, model architecture) pair. The general,
    project-wide configuration is stored with no model architecture ID.
    """

    def __init__(self, db_session: Session) -> None:
        self._db_session = db_session
        self._training_config_repo = TrainingConfigurationRepository(db_session)

    @staticmethod
    def _merge_config_updates(base: dict, updates: dict) -> dict:
        """Return a copy of ``base`` with ``updates`` applied recursively.

        Nested dictionaries are merged key by key; any other value (scalars, lists,
        None) in ``updates`` replaces the value in ``base``. Neither input is mutated.
        """
        merged = dict(base)
        for key, value in updates.items():
            current = merged.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                merged[key] = TrainingConfigurationService._merge_config_updates(current, value)
            else:
                merged[key] = value
        return merged

    def get_training_configuration(
        self,
        project_id: UUID,
        model_architecture_id: str | None = None,
        model_revision_id: UUID | None = None,
    ) -> TrainingConfiguration:
        """
        Retrieves training configuration.

        If model_revision_id is provided, the configuration is loaded from the model entity.
        Otherwise a configuration stored for the project (and model architecture, if given) is
        returned when one exists. Without a stored configuration, the hyperparameters come from
        the manifest of the requested architecture, or from the default model for the project's
        task type when no architecture is given.

        Args:
            project_id (UUID): Identifier for the project.
            model_architecture_id (str | None): Optional ID of the model architecture to retrieve configurations.
            model_revision_id (UUID | None): Optional ID of the model revision to retrieve specific configurations.

        Returns:
            TrainingConfiguration: The training configuration object.

        Raises:
            ValueError: If both IDs are provided, or no default model exists for the task type.
            ResourceNotFoundError: If the model revision or the project does not exist.
        """
        if model_architecture_id and model_revision_id:
            raise ValueError("Only one of model_architecture_id or model_revision_id should be provided.")

        if model_revision_id:
            model = ModelRevisionRepository(self._db_session).get_by_id(str(model_revision_id))
            if not model:
                raise ResourceNotFoundError(ResourceType.MODEL, str(model_revision_id))
            return TrainingConfiguration.from_model(model_config=model.training_configuration)

        # A previously saved configuration wins, both for a specific architecture and for the
        # general configuration (stored without an architecture ID). An empty architecture ID
        # is treated like "not provided".
        stored_config = self._training_config_repo.get_by_project_and_model_architecture(
            project_id=str(project_id),
            model_architecture_id=model_architecture_id or None,
        )
        if stored_config:
            return TrainingConfiguration.model_validate(stored_config.configuration_data)

        if model_architecture_id:
            model_manifest = SupportedModels.get_model_manifest_by_id(model_architecture_id)
            return TrainingConfiguration.from_hyperparameters(hyperparameters=model_manifest.hyperparameters)

        # Load default configuration based on the project's task type
        project = ProjectRepository(self._db_session).get_by_id(str(project_id))
        if not project:
            raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))

        default_model_id = DefaultModels.get_default_model(task_type=project.task_type)
        if not default_model_id:
            raise ValueError(f"No default model found for task type {project.task_type}")
        default_model_manifest = SupportedModels.get_model_manifest_by_id(model_manifest_id=default_model_id)

        return TrainingConfiguration.from_hyperparameters(hyperparameters=default_model_manifest.hyperparameters)

    def update_training_configuration(
        self,
        project_id: UUID,
        training_config_update: dict,
        model_architecture_id: str | None = None,
    ) -> TrainingConfiguration:
        """
        Updates training configuration with provided changes.

        The update is merged into the current configuration, so fields that are not part of
        ``training_config_update`` keep their current values. The merged result is validated
        and then persisted.

        Args:
            project_id (UUID): Identifier for the project.
            training_config_update (dict): Configuration updates to apply.
            model_architecture_id (str | None): Optional ID of the model architecture.

        Returns:
            TrainingConfiguration: The updated training configuration object.

        Raises:
            ResourceNotFoundError: If the project does not exist.
            ValueError: If the merged configuration fails validation.
        """
        project = ProjectRepository(self._db_session).get_by_id(str(project_id))
        if not project:
            raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))

        current_config = self.get_training_configuration(
            project_id=project_id,
            model_architecture_id=model_architecture_id,
        )

        # Validate the merged document rather than the bare update. Validating the update on
        # its own would fill every omitted field with a schema default and wipe current values.
        merged_config = self._merge_config_updates(current_config.model_dump(), training_config_update)
        updated_config = TrainingConfiguration.model_validate(merged_config)

        self._training_config_repo.create_or_update(
            project_id=str(project_id),
            # Same normalisation as on the read path, so the row can be found again.
            model_architecture_id=model_architecture_id or None,
            configuration_data=updated_config.model_dump(),
        )

        return updated_config
a/application/backend/tests/integration/services/test_training_configuration_service.py b/application/backend/tests/integration/services/test_training_configuration_service.py new file mode 100644 index 0000000000..3b33fd681a --- /dev/null +++ b/application/backend/tests/integration/services/test_training_configuration_service.py @@ -0,0 +1,200 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from uuid import UUID, uuid4 + +import pytest +from sqlalchemy.orm import Session + +from app.db.schema import ModelRevisionDB, ProjectDB, TrainingConfigurationDB +from app.schemas import TrainingConfiguration +from app.schemas.project import TaskType +from app.services import ResourceNotFoundError +from app.services.training_configuration_service import TrainingConfigurationService +from app.supported_models.hyperparameters import ( + DatasetPreparationParameters, + EvaluationParameters, + TrainingHyperParameters, +) + + +@pytest.fixture +def fxt_training_configuration() -> TrainingConfiguration: + """Create a mock training configuration.""" + return TrainingConfiguration( + dataset_preparation=DatasetPreparationParameters(), + training=TrainingHyperParameters(), + evaluation=EvaluationParameters(), + ) + + +@pytest.fixture +def fxt_training_configuration_service(db_session: Session) -> TrainingConfigurationService: + return TrainingConfigurationService(db_session) + + +class TestTrainingConfigurationService: + def test_get_training_configuration_by_model_revision_id( + self, fxt_training_configuration, fxt_training_configuration_service, db_session + ): + """Test getting training configuration by model revision ID.""" + project = ProjectDB( + id=str(uuid4()), + name="Test Detection Project", + task_type=TaskType.DETECTION, + exclusive_labels=False, + ) + db_session.add(project) + model = ModelRevisionDB( + id=str(uuid4()), + project_id=project.id, + architecture="Object_Detection_YOLOv5", + training_status="running", + 
training_configuration=fxt_training_configuration.model_dump(), + label_schema_revision={}, + files_deleted=False, + ) + db_session.add(model) + db_session.flush() + + training_configuration = fxt_training_configuration_service.get_training_configuration( + project_id=UUID(project.id), model_revision_id=UUID(model.id) + ) + assert isinstance(training_configuration, TrainingConfiguration) + assert training_configuration == fxt_training_configuration + + def test_get_training_configuration_by_model_revision_id_not_found(self, fxt_training_configuration_service): + """Test getting training configuration with non-existent model revision.""" + with pytest.raises(ResourceNotFoundError): + fxt_training_configuration_service.get_training_configuration(project_id=uuid4(), model_revision_id=uuid4()) + + def test_get_training_configuration_by_model_architecture_id_from_db( + self, fxt_training_configuration, fxt_training_configuration_service, db_session + ): + """Test getting training configuration by model architecture ID from database.""" + project = ProjectDB( + id=str(uuid4()), + name="Test Detection Project", + task_type=TaskType.DETECTION, + exclusive_labels=False, + ) + db_session.add(project) + training_configuration = TrainingConfigurationDB( + id=str(uuid4()), + project_id=project.id, + model_architecture_id="Custom_Object_Detection_YOLOX", + configuration_data=fxt_training_configuration.model_dump(), + ) + db_session.add(training_configuration) + db_session.flush() + + training_configuration = fxt_training_configuration_service.get_training_configuration( + project_id=UUID(project.id), model_architecture_id="Custom_Object_Detection_YOLOX" + ) + assert isinstance(training_configuration, TrainingConfiguration) + assert training_configuration == fxt_training_configuration + + def test_get_training_configuration_by_model_architecture_id_from_manifest( + self, fxt_training_configuration, fxt_training_configuration_service, db_session + ): + """Test getting training 
configuration by model architecture ID from manifest.""" + training_configuration = fxt_training_configuration_service.get_training_configuration( + project_id=uuid4(), model_architecture_id="Custom_Object_Detection_YOLOX" + ) + assert isinstance(training_configuration, TrainingConfiguration) + assert training_configuration != fxt_training_configuration + + def test_get_training_configuration_default_by_project_type( + self, fxt_training_configuration, fxt_training_configuration_service, db_session + ): + """Test getting general training configuration from default model.""" + project = ProjectDB( + id=str(uuid4()), + name="Test Detection Project", + task_type=TaskType.DETECTION, + exclusive_labels=False, + ) + db_session.add(project) + db_session.flush() + + training_configuration = fxt_training_configuration_service.get_training_configuration( + project_id=UUID(project.id), model_architecture_id=None + ) + assert isinstance(training_configuration, TrainingConfiguration) + assert training_configuration != fxt_training_configuration + + def test_get_training_configuration_both_ids_provided_error(self, fxt_training_configuration_service): + """Test error when both model_architecture_id and model_revision_id are provided.""" + with pytest.raises(ValueError) as exc_info: + fxt_training_configuration_service.get_training_configuration( + project_id=uuid4(), model_architecture_id="test_arch", model_revision_id=uuid4() + ) + + assert "Only one of model_architecture_id or model_revision_id should be provided" in str(exc_info.value) + + def test_update_training_configuration_new( + self, fxt_training_configuration, fxt_training_configuration_service, db_session + ): + """Test updating a new training configuration.""" + project = ProjectDB( + id=str(uuid4()), + name="Test Detection Project", + task_type=TaskType.DETECTION, + exclusive_labels=False, + ) + db_session.add(project) + db_session.flush() + + training_config_update = { + "dataset_preparation": {"augmentation": 
{"topdown_affine": {"enable": True, "probability": 0.5}}}, + "training": {"max_epochs": 999}, + "evaluation": {"metric": "new_metric"}, + } + + training_configuration = fxt_training_configuration_service.update_training_configuration( + project_id=UUID(project.id), training_config_update=training_config_update, model_architecture_id=None + ) + assert isinstance(training_configuration, TrainingConfiguration) + assert training_configuration != fxt_training_configuration + assert training_configuration.dataset_preparation.augmentation.topdown_affine.enable is True + assert training_configuration.dataset_preparation.augmentation.topdown_affine.probability == 0.5 + assert training_configuration.training.max_epochs == 999 + assert training_configuration.evaluation.metric == "new_metric" + + def test_update_training_configuration_update( + self, fxt_training_configuration, fxt_training_configuration_service, db_session + ): + """Test updating an existing training configuration with new configuration.""" + project = ProjectDB( + id=str(uuid4()), + name="Test Detection Project", + task_type=TaskType.DETECTION, + exclusive_labels=False, + ) + db_session.add(project) + training_configuration = TrainingConfigurationDB( + id=str(uuid4()), + project_id=project.id, + model_architecture_id="Custom_Object_Detection_YOLOX", + configuration_data=fxt_training_configuration.model_dump(), + ) + db_session.add(training_configuration) + db_session.flush() + + training_config_update = { + "dataset_preparation": {"augmentation": {"topdown_affine": {"enable": True, "probability": 0.5}}}, + "training": {"max_epochs": 999}, + "evaluation": {"metric": "new_metric"}, + } + + training_configuration = fxt_training_configuration_service.update_training_configuration( + project_id=UUID(project.id), + training_config_update=training_config_update, + model_architecture_id="Custom_Object_Detection_YOLOX", + ) + assert isinstance(training_configuration, TrainingConfiguration) + assert 
@pytest.fixture
def fxt_training_configuration() -> TrainingConfiguration:
    """Create a training configuration built from the parameter classes' default values."""
    return TrainingConfiguration(
        dataset_preparation=DatasetPreparationParameters(),
        training=TrainingHyperParameters(),
        evaluation=EvaluationParameters(),
    )


@pytest.fixture
def fxt_training_configuration_service(request: pytest.FixtureRequest) -> MagicMock:
    """Install a mocked TrainingConfigurationService as the app dependency for one test.

    The override is removed again when the test finishes so the mock cannot leak
    into other tests that share the application object.
    """
    training_configuration_service = MagicMock(spec=TrainingConfigurationService)
    app.dependency_overrides[get_training_configuration_service] = lambda: training_configuration_service
    request.addfinalizer(lambda: app.dependency_overrides.pop(get_training_configuration_service, None))
    return training_configuration_service


class TestTrainingConfigurationEndpoints:
    def test_get_training_configuration_success(
        self, fxt_client, fxt_training_configuration, fxt_training_configuration_service
    ):
        """Test successful retrieval of training configuration."""
        project_id = uuid4()
        fxt_training_configuration_service.get_training_configuration.return_value = fxt_training_configuration

        response = fxt_client.get(f"/api/projects/{project_id}/training_configuration")

        assert response.status_code == 200
        data = response.json()
        # The response must expose the three top-level sections of the configuration.
        assert "dataset_preparation" in data
        assert "training" in data
        assert "evaluation" in data