diff --git a/application/backend/app/api/routers/models.py b/application/backend/app/api/routers/models.py index c15831b232..5e5caa1c10 100644 --- a/application/backend/app/api/routers/models.py +++ b/application/backend/app/api/routers/models.py @@ -1,9 +1,12 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import io +import zipfile from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import StreamingResponse from app.api.dependencies import get_model_service, get_project from app.api.schemas import ModelView, ProjectView @@ -55,6 +58,47 @@ def get_model( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) +@router.get( + "/{model_id}/binary", + responses={ + status.HTTP_200_OK: {"description": "Model weights in OpenVINO format (zip archive)"}, + status.HTTP_400_BAD_REQUEST: {"description": "Invalid project or model ID"}, + status.HTTP_404_NOT_FOUND: {"description": "Project or model not found"}, + }, +) +def download_model_binary( + project: Annotated[ProjectView, Depends(get_project)], + model_id: ModelID, + model_service: Annotated[ModelService, Depends(get_model_service)], +) -> StreamingResponse: + """Download trained model weights in OpenVINO format as a zip archive containing model.xml and model.bin files.""" + try: + # Verify the model exists and get the model directory + model_dir = model_service.get_model_files_path(project_id=project.id, model_id=model_id) + + # Create an in-memory zip file + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zip_file: + xml_file = model_dir / "model.xml" + bin_file = model_dir / "model.bin" + + zip_file.write(xml_file, arcname="model.xml") + zip_file.write(bin_file, arcname="model.bin") + + zip_buffer.seek(0) + + # Assume FP16 precision by default + filename = f"model-{model_id}-fp16.zip" + + return StreamingResponse( + zip_buffer, + media_type="application/zip", + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) + except ResourceNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + + @router.delete( "/{model_id}", status_code=status.HTTP_204_NO_CONTENT, diff --git a/application/backend/app/db_seeder.py b/application/backend/app/db_seeder.py index 3165c0b163..dada4c2750 100644 --- a/application/backend/app/db_seeder.py +++ b/application/backend/app/db_seeder.py @@ -140,7 +140,7 @@ def _create_pipeline_with_video_source( # noqa: PLR0913 project_id=project_id, sink_id=sink_id, data_collection_policies=[FixedRateDataCollectionPolicy(rate=0.1).model_dump(mode="json")], - is_running=True, + is_running=project_id == "9d6af8e8-6017-4ebe-9126-33aae739c5fa", # Running only for detection project ) pipeline.source = SourceDB( diff --git a/application/backend/app/services/model_service.py b/application/backend/app/services/model_service.py index ff4dd869c2..2e6a89388a 100644 --- a/application/backend/app/services/model_service.py +++ b/application/backend/app/services/model_service.py @@ -6,6 +6,7 @@ from pathlib import Path from uuid import UUID +from loguru import logger from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -139,3 +140,35 @@ def create_revision(self, metadata: ModelRevisionMetadata) -> None: label_schema_revision=labels_schema_rev, ) ) + + def get_model_files_path(self, project_id: UUID, model_id: UUID) -> Path: + """ + Get the directory path containing the model files (model.xml and model.bin). + + Args: + project_id (UUID): The unique identifier of the project. + model_id (UUID): The unique identifier of the model. + + Returns: + Path: The directory path containing the model files. + + Raises: + ResourceNotFoundError: If the model directory doesn't exist or required files are missing. + FileNotFoundError: If the directories or model files are not found in the expected location. + """ + model_revision = self.get_model(project_id=project_id, model_id=model_id) + if model_revision.files_deleted: + raise ResourceNotFoundError(ResourceType.MODEL, str(model_id)) + + model_dir = self._projects_dir / str(project_id) / "models" / str(model_id) + if not model_dir.exists(): + logger.error("Model directory not found: {}", model_dir) + raise FileNotFoundError + + xml_file = model_dir / "model.xml" + bin_file = model_dir / "model.bin" + if not xml_file.exists() or not bin_file.exists(): + logger.error("Model files missing in directory: {}", model_dir) + raise FileNotFoundError + + return model_dir diff --git a/application/backend/tests/unit/routers/test_models.py b/application/backend/tests/unit/routers/test_models.py index 9f6356d5b4..9c93811dbb 100644 --- a/application/backend/tests/unit/routers/test_models.py +++ b/application/backend/tests/unit/routers/test_models.py @@ -110,3 +110,51 @@ def test_delete_model_in_use(self, fxt_get_project, fxt_model_service, fxt_clien assert response.status_code == status.HTTP_409_CONFLICT assert str(err) == response.json()["detail"] fxt_model_service.delete_model.assert_called_once_with(project_id=fxt_get_project.id, model_id=model_id) + + def test_download_model_binary_success(self, fxt_get_project, fxt_model, fxt_model_service, fxt_client, tmp_path): + import zipfile + from io import BytesIO + + # Create mock model files + model_dir = tmp_path / "models" / str(fxt_model.id) + model_dir.mkdir(parents=True) + xml_content = "model data" + bin_content = b"binary model data" + (model_dir / "model.xml").write_text(xml_content) + (model_dir / "model.bin").write_bytes(bin_content) + + fxt_model_service.get_model_files_path.return_value = model_dir + + response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/{fxt_model.id}/binary") + + assert response.status_code == status.HTTP_200_OK + assert response.headers["content-type"] == "application/zip" + assert "content-disposition" in response.headers + assert f"model-{fxt_model.id}-fp16.zip" in response.headers["content-disposition"] + + # Verify zip file contents + zip_data = BytesIO(response.content) + with zipfile.ZipFile(zip_data, "r") as zip_file: + assert "model.xml" in zip_file.namelist() + assert "model.bin" in zip_file.namelist() + assert zip_file.read("model.xml").decode() == xml_content + assert zip_file.read("model.bin") == bin_content + + fxt_model_service.get_model_files_path.assert_called_once_with( + project_id=fxt_get_project.id, model_id=fxt_model.id + ) + + def test_download_model_binary_not_found(self, fxt_get_project, fxt_model_service, fxt_client): + model_id = uuid4() + fxt_model_service.get_model_files_path.side_effect = ResourceNotFoundError(ResourceType.MODEL, str(model_id)) + + response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/{model_id}/binary") + + assert response.status_code == status.HTTP_404_NOT_FOUND + fxt_model_service.get_model_files_path.assert_called_once_with(project_id=fxt_get_project.id, model_id=model_id) + + def test_download_model_binary_invalid_id(self, fxt_get_project, fxt_model_service, fxt_client): + response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/invalid-id/binary") + + assert response.status_code == status.HTTP_400_BAD_REQUEST + fxt_model_service.get_model.assert_not_called()