Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions application/backend/app/api/routers/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import io
import zipfile
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse

from app.api.dependencies import get_model_service, get_project
from app.api.schemas import ModelView, ProjectView
Expand Down Expand Up @@ -55,6 +58,47 @@ def get_model(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))


@router.get(
"/{model_id}/binary",
responses={
status.HTTP_200_OK: {"description": "Model weights in OpenVINO format (zip archive)"},
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project or model ID"},
status.HTTP_404_NOT_FOUND: {"description": "Project or model not found"},
},
)
def download_model_binary(
project: Annotated[ProjectView, Depends(get_project)],
model_id: ModelID,
model_service: Annotated[ModelService, Depends(get_model_service)],
) -> StreamingResponse:
"""Download trained model weights in OpenVINO format as a zip archive containing model.xml and model.bin files."""
try:
# Verify the model exists and get the model directory
model_dir = model_service.get_model_files_path(project_id=project.id, model_id=model_id)

# Create an in-memory zip file
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zip_file:
xml_file = model_dir / "model.xml"
bin_file = model_dir / "model.bin"

zip_file.write(xml_file, arcname="model.xml")
zip_file.write(bin_file, arcname="model.bin")

zip_buffer.seek(0)

# Assume FP16 precision by default
filename = f"model-{model_id}-fp16.zip"

return StreamingResponse(
zip_buffer,
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
except ResourceNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))


@router.delete(
"/{model_id}",
status_code=status.HTTP_204_NO_CONTENT,
Expand Down
2 changes: 1 addition & 1 deletion application/backend/app/db_seeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _create_pipeline_with_video_source( # noqa: PLR0913
project_id=project_id,
sink_id=sink_id,
data_collection_policies=[FixedRateDataCollectionPolicy(rate=0.1).model_dump(mode="json")],
is_running=True,
is_running=project_id == "9d6af8e8-6017-4ebe-9126-33aae739c5fa", # Running only for detection project
)

pipeline.source = SourceDB(
Expand Down
33 changes: 33 additions & 0 deletions application/backend/app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from uuid import UUID

from loguru import logger
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -139,3 +140,35 @@ def create_revision(self, metadata: ModelRevisionMetadata) -> None:
label_schema_revision=labels_schema_rev,
)
)

def get_model_files_path(self, project_id: UUID, model_id: UUID) -> Path:
"""
Get the directory path containing the model files (model.xml and model.bin).

Args:
project_id (UUID): The unique identifier of the project.
model_id (UUID): The unique identifier of the model.

Returns:
Path: The directory path containing the model files.

Raises:
ResourceNotFoundError: If the model directory doesn't exist or required files are missing.
FileNotFoundError: If the directories or model files are not found in the expected location.
"""
model_revision = self.get_model(project_id=project_id, model_id=model_id)
if model_revision.files_deleted:
raise ResourceNotFoundError(ResourceType.MODEL, str(model_id))

model_dir = self._projects_dir / str(project_id) / "models" / str(model_id)
if not model_dir.exists():
logger.error("Model directory not found: {}", model_dir)
raise FileNotFoundError

xml_file = model_dir / "model.xml"
bin_file = model_dir / "model.bin"
if not xml_file.exists() or not bin_file.exists():
logger.error("Model files missing in directory: {}", model_dir)
raise FileNotFoundError

return model_dir
48 changes: 48 additions & 0 deletions application/backend/tests/unit/routers/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,51 @@ def test_delete_model_in_use(self, fxt_get_project, fxt_model_service, fxt_clien
assert response.status_code == status.HTTP_409_CONFLICT
assert str(err) == response.json()["detail"]
fxt_model_service.delete_model.assert_called_once_with(project_id=fxt_get_project.id, model_id=model_id)

def test_download_model_binary_success(self, fxt_get_project, fxt_model, fxt_model_service, fxt_client, tmp_path):
import zipfile
from io import BytesIO

# Create mock model files
model_dir = tmp_path / "models" / str(fxt_model.id)
model_dir.mkdir(parents=True)
xml_content = "<xml>model data</xml>"
bin_content = b"binary model data"
(model_dir / "model.xml").write_text(xml_content)
(model_dir / "model.bin").write_bytes(bin_content)

fxt_model_service.get_model_files_path.return_value = model_dir

response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/{fxt_model.id}/binary")

assert response.status_code == status.HTTP_200_OK
assert response.headers["content-type"] == "application/zip"
assert "content-disposition" in response.headers
assert f"model-{fxt_model.id}-fp16.zip" in response.headers["content-disposition"]

# Verify zip file contents
zip_data = BytesIO(response.content)
with zipfile.ZipFile(zip_data, "r") as zip_file:
assert "model.xml" in zip_file.namelist()
assert "model.bin" in zip_file.namelist()
assert zip_file.read("model.xml").decode() == xml_content
assert zip_file.read("model.bin") == bin_content

fxt_model_service.get_model_files_path.assert_called_once_with(
project_id=fxt_get_project.id, model_id=fxt_model.id
)

def test_download_model_binary_not_found(self, fxt_get_project, fxt_model_service, fxt_client):
model_id = uuid4()
fxt_model_service.get_model_files_path.side_effect = ResourceNotFoundError(ResourceType.MODEL, str(model_id))

response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/{model_id}/binary")

assert response.status_code == status.HTTP_404_NOT_FOUND
fxt_model_service.get_model_files_path.assert_called_once_with(project_id=fxt_get_project.id, model_id=model_id)

def test_download_model_binary_invalid_id(self, fxt_get_project, fxt_model_service, fxt_client):
response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/invalid-id/binary")

assert response.status_code == status.HTTP_400_BAD_REQUEST
fxt_model_service.get_model.assert_not_called()
Loading