Skip to content

Commit 3896f35

Browse files
A-Artemisleoll2
andauthored
Training Configurations Add/Update implementation (#4873)
Co-authored-by: Leonardo Lai <[email protected]>
1 parent bc897fd commit 3896f35

30 files changed

+2204
-161
lines changed

application/backend/app/alembic/versions/2786b50eb5a4_schema.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,19 @@ def upgrade() -> None:
156156
sa.ForeignKeyConstraint(["label_id"], ["labels.id"], ondelete="CASCADE"),
157157
sa.PrimaryKeyConstraint("dataset_item_id", "label_id"),
158158
)
159+
op.create_table(
160+
"training_configurations",
161+
sa.Column("id", sa.Text(), nullable=False),
162+
sa.Column("project_id", sa.Text(), nullable=False),
163+
sa.Column("model_architecture_id", sa.String(length=255), nullable=True),
164+
sa.Column("configuration_data", sa.JSON(), nullable=False),
165+
sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False),
166+
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False),
167+
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
168+
sa.PrimaryKeyConstraint("id"),
169+
sa.UniqueConstraint("project_id", "model_architecture_id", name="uq_project_model_config"),
170+
)
171+
159172
# ### end Alembic commands ###
160173

161174

@@ -177,5 +190,6 @@ def downgrade() -> None:
177190
op.drop_table("sources")
178191
op.drop_table("sinks")
179192
op.drop_index("idx_projects_name", table_name="projects")
193+
op.drop_table("training_configurations")
180194
op.drop_table("projects")
181195
# ### end Alembic commands ###

application/backend/app/api/dependencies.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from app.services.data_collect import DataCollector
3131
from app.services.event.event_bus import EventBus
3232
from app.services.label_service import LabelService
33+
from app.services.training_configuration_service import TrainingConfigurationService
3334
from app.webrtc.manager import WebRTCManager
3435

3536

@@ -268,3 +269,8 @@ def get_base_weights_service(data_dir: Annotated[Path, Depends(get_data_dir)]) -
268269
def get_job_queue(request: Request) -> JobQueue:
269270
"""Provides the global JobQueue instance from FastAPI application's state."""
270271
return request.app.state.job_queue
272+
273+
274+
def get_training_configuration_service(db: Annotated[Session, Depends(get_db)]) -> TrainingConfigurationService:
275+
"""Provides a TrainingConfigurationService instance for managing training configurations."""
276+
return TrainingConfigurationService(db_session=db)

application/backend/app/api/routers/projects.py

Lines changed: 2 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
from typing import Annotated
77
from uuid import UUID
88

9-
from fastapi import APIRouter, Body, Depends, Query, status
9+
from fastapi import APIRouter, Body, Depends, status
1010
from fastapi.exceptions import HTTPException
1111
from fastapi.openapi.models import Example
1212
from starlette.responses import FileResponse
1313

1414
from app.api.dependencies import get_data_collector, get_label_service, get_project, get_project_id, get_project_service
15-
from app.schemas import LabelView, PatchLabels, ProjectCreate, ProjectUpdateName, ProjectView, TrainingConfiguration
15+
from app.schemas import LabelView, PatchLabels, ProjectCreate, ProjectUpdateName, ProjectView
1616
from app.services import (
1717
LabelService,
1818
ProjectService,
@@ -22,7 +22,6 @@
2222
)
2323
from app.services.data_collect import DataCollector
2424
from app.services.label_service import DuplicateLabelsError
25-
from app.supported_models.hyperparameters import Hyperparameters
2625

2726
router = APIRouter(prefix="/api/projects", tags=["Projects"])
2827

@@ -277,82 +276,3 @@ def capture_next_pipeline_frame(
277276
data_collector.collect_next_frame()
278277
except ResourceNotFoundError as e:
279278
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
280-
281-
282-
@router.get(
283-
"/{project_id}/training_configuration",
284-
response_model=TrainingConfiguration,
285-
responses={
286-
status.HTTP_200_OK: {"description": "Training configuration found"},
287-
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID or query parameters"},
288-
status.HTTP_404_NOT_FOUND: {"description": "Project not found"},
289-
},
290-
)
291-
def get_training_configuration(
292-
project_id: Annotated[UUID, Depends(get_project_id)],
293-
project_service: Annotated[ProjectService, Depends(get_project_service)],
294-
model_architecture_id: Annotated[str | None, Query()] = None,
295-
model_revision_id: Annotated[UUID | None, Query()] = None,
296-
) -> TrainingConfiguration:
297-
"""
298-
Get the training configuration for a project.
299-
300-
- If model_architecture_id is provided, returns configuration for that specific model architecture.
301-
- If model_revision_id is provided, returns configuration for a specific trained model.
302-
- If neither is provided, returns only general task-related configuration.
303-
Note: model_architecture_id and model_revision_id cannot be used together.
304-
305-
Args:
306-
project_id (UUID): The unique identifier of the project.
307-
project_service (ProjectService): The project service
308-
model_architecture_id (Optional[str]): The model architecture ID for specific configuration retrieval.
309-
model_revision_id (Optional[UUID]): The model revision ID for specific configuration retrieval.
310-
311-
Returns:
312-
TrainingConfiguration: The training configuration details.
313-
"""
314-
try:
315-
# TODO: Implement actual training configuration retrieval logic
316-
_ = project_id, project_service, model_architecture_id, model_revision_id
317-
return TrainingConfiguration.from_hyperparameters(hyperparameters=Hyperparameters())
318-
except ResourceNotFoundError as e:
319-
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
320-
except ValueError as e:
321-
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
322-
323-
324-
@router.patch(
325-
"/{project_id}/training_configuration",
326-
status_code=status.HTTP_204_NO_CONTENT,
327-
responses={
328-
status.HTTP_204_NO_CONTENT: {"description": "Training configuration updated successfully"},
329-
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID, query parameters, or request body"},
330-
status.HTTP_404_NOT_FOUND: {"description": "Project not found"},
331-
},
332-
)
333-
def update_training_configuration(
334-
project_id: Annotated[UUID, Depends(get_project_id)],
335-
training_config_update: Annotated[TrainingConfiguration, Body(description="Training configuration updates")],
336-
project_service: Annotated[ProjectService, Depends(get_project_service)],
337-
model_architecture_id: Annotated[str | None, Query()] = None,
338-
) -> None:
339-
"""
340-
Update the training configuration for a project.
341-
342-
- If model_architecture_id is provided, updates configuration for that specific model architecture.
343-
- If not provided, updates the general task-related configuration.
344-
Note: model_architecture_id cannot be used with model_revision_id for updates.
345-
346-
Args:
347-
project_id (UUID): The unique identifier of the project.
348-
training_config_update (TrainingConfiguration): The training configuration updates.
349-
project_service (ProjectService): The project service
350-
model_architecture_id (Optional[str]): The model architecture ID for specific configuration update.
351-
"""
352-
try:
353-
# TODO: Implement actual training configuration update logic
354-
_ = project_id, training_config_update, project_service, model_architecture_id
355-
except ResourceNotFoundError as e:
356-
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
357-
except ValueError as e:
358-
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from typing import Annotated
5+
from uuid import UUID
6+
7+
from fastapi import APIRouter, Depends, HTTPException, Query, status
8+
9+
from app.api.dependencies import get_project_id, get_training_configuration_service
10+
from app.api.serializers.configurable_parameters import ConfigurableParametersConverter
11+
from app.api.serializers.training_configuration import TrainingConfigurationConverter
12+
from app.services import ResourceNotFoundError
13+
from app.services.training_configuration_service import TrainingConfigurationService
14+
15+
router = APIRouter(prefix="/api/projects/{project_id}/training_configuration", tags=["Training Configuration"])
16+
17+
18+
@router.get("")
19+
def get_training_configuration(
20+
training_configuration_service: Annotated[
21+
TrainingConfigurationService, Depends(get_training_configuration_service)
22+
],
23+
project_id: Annotated[UUID, Depends(get_project_id)],
24+
model_architecture_id: Annotated[str | None, Query()] = None,
25+
model_revision_id: Annotated[UUID | None, Query()] = None,
26+
) -> dict:
27+
"""
28+
Get the training configuration for a project.
29+
30+
If model_architecture_id is provided, returns configuration for that specific model architecture.
31+
If model_revision_id is provided, returns configuration for a specific trained model.
32+
If neither is provided, returns only general task-related configuration.
33+
Note: model_architecture_id and model_revision_id cannot be used together.
34+
"""
35+
try:
36+
training_config = training_configuration_service.get_training_configuration(
37+
project_id=project_id,
38+
model_architecture_id=model_architecture_id,
39+
model_revision_id=model_revision_id,
40+
)
41+
return TrainingConfigurationConverter().training_configuration_to_rest(training_configuration=training_config)
42+
except ResourceNotFoundError as e:
43+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
44+
except ValueError as e:
45+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
46+
47+
48+
@router.patch("", status_code=status.HTTP_200_OK)
49+
def update_training_configuration(
50+
training_configuration_service: Annotated[
51+
TrainingConfigurationService, Depends(get_training_configuration_service)
52+
],
53+
project_id: Annotated[UUID, Depends(get_project_id)],
54+
training_config_update: dict,
55+
model_architecture_id: Annotated[str | None, Query()] = None,
56+
) -> dict:
57+
"""
58+
Update the training configuration for a project.
59+
60+
- If model_architecture_id is provided, updates configuration for that specific model architecture.
61+
- If not provided, updates the general task-related configuration.
62+
"""
63+
try:
64+
converted_config = ConfigurableParametersConverter.configurable_parameters_from_rest(training_config_update)
65+
updated_config = training_configuration_service.update_training_configuration(
66+
project_id=project_id,
67+
training_config_update=converted_config,
68+
model_architecture_id=model_architecture_id,
69+
)
70+
return TrainingConfigurationConverter().training_configuration_to_rest(updated_config)
71+
except ResourceNotFoundError as e:
72+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
73+
except ValueError as e:
74+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0

0 commit comments

Comments
 (0)