Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions application/backend/app/api/routers/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import io
import zipfile
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse

from app.api.dependencies import get_model_service, get_project
from app.api.schemas import ModelView, ProjectView
Expand Down Expand Up @@ -55,6 +58,48 @@ def get_model(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))


@router.get(
"/{model_id}/binary",
responses={
status.HTTP_200_OK: {"description": "Model weights in OpenVINO format (zip archive)"},
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project or model ID"},
status.HTTP_404_NOT_FOUND: {"description": "Project or model not found"},
},
)
def download_model_binary(
project: Annotated[ProjectView, Depends(get_project)],
model_id: ModelID,
model_service: Annotated[ModelService, Depends(get_model_service)],
) -> StreamingResponse:
"""Download trained model weights in OpenVINO format as a zip archive containing model.xml and model.bin files."""
try:
# Verify the model exists and get the model directory
_ = model_service.get_model(project_id=project.id, model_id=model_id)
model_dir = model_service.get_model_files_path(project_id=project.id, model_id=model_id)

# Create an in-memory zip file
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zip_file:
xml_file = model_dir / "model.xml"
bin_file = model_dir / "model.bin"

zip_file.write(xml_file, arcname="model.xml")
zip_file.write(bin_file, arcname="model.bin")

zip_buffer.seek(0)

# Assume FP16 precision by default
filename = f"model-{model_id}-fp16.zip"

return StreamingResponse(
zip_buffer,
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
except ResourceNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))


@router.delete(
"/{model_id}",
status_code=status.HTTP_204_NO_CONTENT,
Expand Down
2 changes: 1 addition & 1 deletion application/backend/app/db_seeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _create_pipeline_with_video_source( # noqa: PLR0913
project_id=project_id,
sink_id=sink_id,
data_collection_policies=[FixedRateDataCollectionPolicy(rate=0.1).model_dump(mode="json")],
is_running=True,
is_running=False,
)

pipeline.source = SourceDB(
Expand Down
28 changes: 28 additions & 0 deletions application/backend/app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,31 @@ def create_revision(self, metadata: ModelRevisionMetadata) -> None:
label_schema_revision=labels_schema_rev,
)
)

def get_model_files_path(self, project_id: UUID, model_id: UUID) -> Path:
"""
Get the directory path containing the model files (model.xml and model.bin).

Args:
project_id (UUID): The unique identifier of the project.
model_id (UUID): The unique identifier of the model.

Returns:
Path: The directory path containing the model files.

Raises:
ResourceNotFoundError: If the model directory doesn't exist or required files are missing.
"""
model_dir = self._projects_dir / str(project_id) / "models" / str(model_id)

if not model_dir.exists():
raise ResourceNotFoundError(ResourceType.MODEL, str(model_id))

# Verify that the required files exist
xml_file = model_dir / "model.xml"
bin_file = model_dir / "model.bin"

if not xml_file.exists() or not bin_file.exists():
raise ResourceNotFoundError(ResourceType.MODEL, str(model_id))

return model_dir
50 changes: 50 additions & 0 deletions application/backend/tests/unit/routers/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,53 @@ def test_delete_model_in_use(self, fxt_get_project, fxt_model_service, fxt_clien
assert response.status_code == status.HTTP_409_CONFLICT
assert str(err) == response.json()["detail"]
fxt_model_service.delete_model.assert_called_once_with(project_id=fxt_get_project.id, model_id=model_id)

def test_download_model_binary_success(self, fxt_get_project, fxt_model, fxt_model_service, fxt_client, tmp_path):
import zipfile
from io import BytesIO

# Create mock model files
model_dir = tmp_path / "models" / str(fxt_model.id)
model_dir.mkdir(parents=True)
xml_content = "<xml>model data</xml>"
bin_content = b"binary model data"
(model_dir / "model.xml").write_text(xml_content)
(model_dir / "model.bin").write_bytes(bin_content)

fxt_model_service.get_model.return_value = fxt_model
fxt_model_service.get_model_files_path.return_value = model_dir

response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/{fxt_model.id}/binary")

assert response.status_code == status.HTTP_200_OK
assert response.headers["content-type"] == "application/zip"
assert "content-disposition" in response.headers
assert f"model-{fxt_model.id}-fp16.zip" in response.headers["content-disposition"]

# Verify zip file contents
zip_data = BytesIO(response.content)
with zipfile.ZipFile(zip_data, "r") as zip_file:
assert "model.xml" in zip_file.namelist()
assert "model.bin" in zip_file.namelist()
assert zip_file.read("model.xml").decode() == xml_content
assert zip_file.read("model.bin") == bin_content

fxt_model_service.get_model.assert_called_once_with(project_id=fxt_get_project.id, model_id=fxt_model.id)
fxt_model_service.get_model_files_path.assert_called_once_with(
project_id=fxt_get_project.id, model_id=fxt_model.id
)

def test_download_model_binary_not_found(self, fxt_get_project, fxt_model_service, fxt_client):
model_id = uuid4()
fxt_model_service.get_model.side_effect = ResourceNotFoundError(ResourceType.MODEL, str(model_id))

response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/{model_id}/binary")

assert response.status_code == status.HTTP_404_NOT_FOUND
fxt_model_service.get_model.assert_called_once_with(project_id=fxt_get_project.id, model_id=model_id)

def test_download_model_binary_invalid_id(self, fxt_get_project, fxt_model_service, fxt_client):
response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/invalid-id/binary")

assert response.status_code == status.HTTP_400_BAD_REQUEST
fxt_model_service.get_model.assert_not_called()
Loading