Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions application/backend/app/alembic/versions/da385d690aae_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,19 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="RESTRICT"),
sa.PrimaryKeyConstraint("project_id"),
)
op.create_table(
"training_configurations",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("project_id", sa.Text(), nullable=False),
sa.Column("model_architecture_id", sa.String(length=100), nullable=True),
sa.Column("configuration_data", sa.JSON(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False),
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("project_id", "model_architecture_id", name="uq_project_model_config"),
)

# ### end Alembic commands ###


Expand All @@ -152,4 +165,5 @@ def downgrade() -> None:
op.drop_table("sources")
op.drop_table("sinks")
op.drop_table("projects")
op.drop_table("training_configurations")
# ### end Alembic commands ###
6 changes: 6 additions & 0 deletions application/backend/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from app.services.base_weights_service import BaseWeightsService
from app.services.data_collect import DataCollector
from app.services.label_service import LabelService
from app.services.training_configuration_service import TrainingConfigurationService
from app.webrtc.manager import WebRTCManager


Expand Down Expand Up @@ -198,3 +199,8 @@ def get_label_service(db: Annotated[Session, Depends(get_db)]) -> LabelService:
def get_base_weights_service(data_dir: Annotated[Path, Depends(get_data_dir)]) -> BaseWeightsService:
    """
    Create a BaseWeightsService for managing base weights.

    Args:
        data_dir (Path): The data directory resolved by the ``get_data_dir`` dependency.

    Returns:
        BaseWeightsService: A service instance constructed from ``data_dir``.
    """
    base_weights_service = BaseWeightsService(data_dir)
    return base_weights_service


def get_training_configuration_service(db: Annotated[Session, Depends(get_db)]) -> TrainingConfigurationService:
    """
    Create a TrainingConfigurationService for managing training configurations.

    Args:
        db (Session): The database session resolved by the ``get_db`` dependency.

    Returns:
        TrainingConfigurationService: A service instance bound to ``db``.
    """
    training_configuration_service = TrainingConfigurationService(db_session=db)
    return training_configuration_service
84 changes: 2 additions & 82 deletions application/backend/app/api/endpoints/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Body, Depends, Query, status
from fastapi import APIRouter, Body, Depends, status
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import Example
from starlette.responses import FileResponse

from app.api.dependencies import get_data_collector, get_label_service, get_project_id, get_project_service
from app.schemas import Label, PatchLabels, ProjectCreate, ProjectUpdateName, ProjectView, TrainingConfiguration
from app.schemas import Label, PatchLabels, ProjectCreate, ProjectUpdateName, ProjectView
from app.services import (
LabelService,
ProjectService,
Expand All @@ -22,7 +22,6 @@
ResourceNotFoundError,
)
from app.services.data_collect import DataCollector
from app.supported_models.hyperparameters import Hyperparameters

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects", tags=["Projects"])
Expand Down Expand Up @@ -265,82 +264,3 @@ def capture_next_pipeline_frame(
data_collector.collect_next_frame()
except ResourceNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))


@router.get(
    "/{project_id}/training_configuration",
    response_model=TrainingConfiguration,
    responses={
        status.HTTP_200_OK: {"description": "Training configuration found"},
        status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID or query parameters"},
        status.HTTP_404_NOT_FOUND: {"description": "Project not found"},
    },
)
def get_training_configuration(
    project_id: Annotated[UUID, Depends(get_project_id)],
    project_service: Annotated[ProjectService, Depends(get_project_service)],
    model_architecture_id: Annotated[str | None, Query()] = None,
    model_revision_id: Annotated[UUID | None, Query()] = None,
) -> TrainingConfiguration:
    """
    Return the training configuration of a project.

    Intended contract (not implemented yet, see the TODO below):

    - With model_architecture_id, the configuration of that model architecture is returned.
    - With model_revision_id, the configuration of that trained model is returned.
    - With neither, only the general task-related configuration is returned.
    - model_architecture_id and model_revision_id are mutually exclusive.

    The current stub ignores every argument and always returns the configuration
    built from default ``Hyperparameters()``.

    Args:
        project_id (UUID): The unique identifier of the project.
        project_service (ProjectService): The project service.
        model_architecture_id (Optional[str]): The model architecture ID for specific configuration retrieval.
        model_revision_id (Optional[UUID]): The model revision ID for specific configuration retrieval.

    Returns:
        TrainingConfiguration: The training configuration details.

    Raises:
        HTTPException: 404 on ResourceNotFoundError, 400 on ValueError.
    """
    try:
        # TODO: Implement actual training configuration retrieval logic
        _ = (project_id, project_service, model_architecture_id, model_revision_id)
        default_hyperparameters = Hyperparameters()
        return TrainingConfiguration.from_hyperparameters(hyperparameters=default_hyperparameters)
    except ResourceNotFoundError as not_found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(not_found))
    except ValueError as invalid:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(invalid))


@router.patch(
    "/{project_id}/training_configuration",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={
        status.HTTP_204_NO_CONTENT: {"description": "Training configuration updated successfully"},
        status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID, query parameters, or request body"},
        status.HTTP_404_NOT_FOUND: {"description": "Project not found"},
    },
)
def update_training_configuration(
    project_id: Annotated[UUID, Depends(get_project_id)],
    training_config_update: Annotated[TrainingConfiguration, Body(description="Training configuration updates")],
    project_service: Annotated[ProjectService, Depends(get_project_service)],
    model_architecture_id: Annotated[str | None, Query()] = None,
) -> None:
    """
    Update the training configuration of a project.

    Intended contract (not implemented yet, see the TODO below):

    - With model_architecture_id, the configuration of that model architecture is updated.
    - Without it, the general task-related configuration is updated.
    - model_architecture_id cannot be combined with model_revision_id for updates.

    The current stub accepts the request and applies no change.

    Args:
        project_id (UUID): The unique identifier of the project.
        training_config_update (TrainingConfiguration): The training configuration updates.
        project_service (ProjectService): The project service.
        model_architecture_id (Optional[str]): The model architecture ID for specific configuration update.

    Raises:
        HTTPException: 404 on ResourceNotFoundError, 400 on ValueError.
    """
    try:
        # TODO: Implement actual training configuration update logic
        _ = (project_id, training_config_update, project_service, model_architecture_id)
    except ResourceNotFoundError as not_found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(not_found))
    except ValueError as invalid:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(invalid))
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status

from app.api.dependencies import get_training_configuration_service
from app.schemas import TrainingConfiguration
from app.services import ResourceNotFoundError
from app.services.training_configuration_service import TrainingConfigurationService

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects/{project_id}/training_configuration", tags=["Training Configuration"])


@router.get("", response_model=TrainingConfiguration)
def get_training_configuration(
    training_configuration_service: Annotated[
        TrainingConfigurationService, Depends(get_training_configuration_service)
    ],
    project_id: UUID,
    model_architecture_id: str | None = None,
    model_revision_id: UUID | None = None,
) -> TrainingConfiguration:
    """
    Get the training configuration for a project.

    If model_architecture_id is provided, returns configuration for that specific model architecture.
    If model_revision_id is provided, returns configuration for a specific trained model.
    If neither is provided, returns only general task-related configuration.
    Note: model_architecture_id and model_revision_id cannot be used together.

    Args:
        training_configuration_service (TrainingConfigurationService): The training configuration service.
        project_id (UUID): The unique identifier of the project.
        model_architecture_id (Optional[str]): The model architecture ID for specific configuration retrieval.
        model_revision_id (Optional[UUID]): The model revision ID for specific configuration retrieval.

    Returns:
        TrainingConfiguration: The training configuration details.

    Raises:
        HTTPException: 404 if the service raises ResourceNotFoundError,
            400 if the service raises ValueError.
    """
    try:
        return training_configuration_service.get_training_configuration(
            project_id=project_id,
            model_architecture_id=model_architecture_id,
            model_revision_id=model_revision_id,
        )
    except ResourceNotFoundError as e:
        # Chain the service error so it stays visible as the cause in tracebacks and logs.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e


@router.patch("", status_code=status.HTTP_204_NO_CONTENT)
def update_training_configuration(
    training_configuration_service: Annotated[
        TrainingConfigurationService, Depends(get_training_configuration_service)
    ],
    project_id: UUID,
    training_config_update: dict,
    model_architecture_id: str | None = None,
) -> None:
    """
    Update the training configuration for a project.

    - If model_architecture_id is provided, updates configuration for that specific model architecture.
    - If not provided, updates the general task-related configuration.
    Note: model_architecture_id cannot be used with model_revision_id for updates.

    Request body should contain elements of the configuration hierarchy to update:
    ```json
    {
        "dataset_preparation": {...},
        "training": {...},
        "evaluation": {...}
    }
    ```

    Args:
        training_configuration_service (TrainingConfigurationService): The training configuration service.
        project_id (UUID): The unique identifier of the project.
        training_config_update (dict): The configuration updates to apply.
        model_architecture_id (Optional[str]): The model architecture ID for specific configuration update.

    Raises:
        HTTPException: 404 if the service raises ResourceNotFoundError,
            400 if the service raises ValueError.
    """
    try:
        training_configuration_service.update_training_configuration(
            project_id=project_id,
            training_config_update=training_config_update,
            model_architecture_id=model_architecture_id,
        )
    except ResourceNotFoundError as e:
        # Chain the service error so it stays visible as the cause in tracebacks and logs.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
11 changes: 11 additions & 0 deletions application/backend/app/db/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,14 @@ class LabelDB(BaseID):
hotkey: Mapped[str | None] = mapped_column(String(10), nullable=True)

project = relationship("ProjectDB", back_populates="labels")


class TrainingConfigurationDB(BaseID):
    """
    ORM model for the ``training_configurations`` table.

    Each row holds one JSON configuration document belonging to a project. A unique
    constraint is declared over ``(project_id, model_architecture_id)``.
    """

    __tablename__ = "training_configurations"
    __table_args__ = (
        UniqueConstraint(
            "project_id",
            "model_architecture_id",
            name="uq_project_model_config",
        ),
    )

    # Owning project; the foreign key is declared with ON DELETE CASCADE.
    project_id: Mapped[str] = mapped_column(
        Text,
        ForeignKey("projects.id", ondelete="CASCADE"),
        nullable=False,
    )
    # NULL for general config.
    # NOTE(review): the Alembic migration creates this column as String(length=100)
    # while the model declares String(255) - confirm which length is intended.
    model_architecture_id: Mapped[str | None] = mapped_column(
        String(255),
        nullable=True,
    )
    # The stored configuration document.
    configuration_data: Mapped[dict] = mapped_column(
        JSON,
        nullable=False,
    )

    project = relationship("ProjectDB")
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from sqlalchemy.orm import Session

from app.db.schema import TrainingConfigurationDB
from app.repositories.base import BaseRepository


class TrainingConfigurationRepository(BaseRepository[TrainingConfigurationDB]):
    """Repository for reading and writing ``TrainingConfigurationDB`` rows."""

    def __init__(self, db: Session) -> None:
        super().__init__(db, TrainingConfigurationDB)

    def get_by_project_and_model_architecture(
        self,
        project_id: str,
        model_architecture_id: str | None = None,
    ) -> TrainingConfigurationDB | None:
        """
        Look up the training configuration of a project, optionally for one model architecture.

        Args:
            project_id (str): The ID of the project.
            model_architecture_id (str | None): The ID of the model architecture.

        Returns:
            TrainingConfigurationDB | None: The first matching row, or None when there is no match.
        """
        matches_project = TrainingConfigurationDB.project_id == project_id
        # NOTE(review): with model_architecture_id=None SQLAlchemy is expected to render
        # this comparison as IS NULL, i.e. to select the general configuration - confirm.
        matches_architecture = TrainingConfigurationDB.model_architecture_id == model_architecture_id
        query = self.db.query(TrainingConfigurationDB).filter(matches_project, matches_architecture)
        return query.first()

    def create_or_update(
        self,
        project_id: str,
        model_architecture_id: str | None,
        configuration_data: dict,
    ) -> TrainingConfigurationDB:
        """
        Store a training configuration, overwriting the existing one if present.

        When a row already exists for the given project and model architecture, its
        configuration data is replaced; otherwise a new row is created.

        Args:
            project_id (str): The ID of the project.
            model_architecture_id (str | None): The ID of the model architecture.
            configuration_data (dict): The configuration data to store.

        Returns:
            TrainingConfigurationDB: The created or updated training configuration.
        """
        # NOTE(review): a reviewer asked whether this existence check is needed or whether
        # insert-vs-update could be detected from the provided ids by default. The lookup
        # here is by (project_id, model_architecture_id) and callers pass no row id, which
        # is presumably why the check is explicit - confirm before removing it.
        config = self.get_by_project_and_model_architecture(project_id, model_architecture_id)

        if config:
            config.configuration_data = configuration_data
        else:
            config = TrainingConfigurationDB(
                project_id=project_id,
                model_architecture_id=model_architecture_id,
                configuration_data=configuration_data,
            )

        self.save(config)
        return config
Loading
Loading