Skip to content

Commit 2d11d47

Browse files
authored
Download model weights (#5059)
1 parent 62e7923 commit 2d11d47

File tree

4 files changed

+126
-1
lines changed

4 files changed

+126
-1
lines changed

application/backend/app/api/routers/models.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# Copyright (C) 2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import io
5+
import zipfile
46
from typing import Annotated
57

68
from fastapi import APIRouter, Depends, HTTPException, status
9+
from fastapi.responses import StreamingResponse
710

811
from app.api.dependencies import get_model_service, get_project
912
from app.api.schemas import ModelView, ProjectView
@@ -55,6 +58,47 @@ def get_model(
5558
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
5659

5760

61+
@router.get(
62+
"/{model_id}/binary",
63+
responses={
64+
status.HTTP_200_OK: {"description": "Model weights in OpenVINO format (zip archive)"},
65+
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project or model ID"},
66+
status.HTTP_404_NOT_FOUND: {"description": "Project or model not found"},
67+
},
68+
)
69+
def download_model_binary(
70+
project: Annotated[ProjectView, Depends(get_project)],
71+
model_id: ModelID,
72+
model_service: Annotated[ModelService, Depends(get_model_service)],
73+
) -> StreamingResponse:
74+
"""Download trained model weights in OpenVINO format as a zip archive containing model.xml and model.bin files."""
75+
try:
76+
# Verify the model exists and get the model directory
77+
model_dir = model_service.get_model_files_path(project_id=project.id, model_id=model_id)
78+
79+
# Create an in-memory zip file
80+
zip_buffer = io.BytesIO()
81+
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zip_file:
82+
xml_file = model_dir / "model.xml"
83+
bin_file = model_dir / "model.bin"
84+
85+
zip_file.write(xml_file, arcname="model.xml")
86+
zip_file.write(bin_file, arcname="model.bin")
87+
88+
zip_buffer.seek(0)
89+
90+
# Assume FP16 precision by default
91+
filename = f"model-{model_id}-fp16.zip"
92+
93+
return StreamingResponse(
94+
zip_buffer,
95+
media_type="application/zip",
96+
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
97+
)
98+
except ResourceNotFoundError as e:
99+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
100+
101+
58102
@router.delete(
59103
"/{model_id}",
60104
status_code=status.HTTP_204_NO_CONTENT,

application/backend/app/db_seeder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _create_pipeline_with_video_source( # noqa: PLR0913
140140
project_id=project_id,
141141
sink_id=sink_id,
142142
data_collection_policies=[FixedRateDataCollectionPolicy(rate=0.1).model_dump(mode="json")],
143-
is_running=True,
143+
is_running=project_id == "9d6af8e8-6017-4ebe-9126-33aae739c5fa", # Running only for detection project
144144
)
145145

146146
pipeline.source = SourceDB(

application/backend/app/services/model_service.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77
from uuid import UUID
88

9+
from loguru import logger
910
from sqlalchemy.exc import IntegrityError
1011
from sqlalchemy.orm import Session
1112

@@ -139,3 +140,35 @@ def create_revision(self, metadata: ModelRevisionMetadata) -> None:
139140
label_schema_revision=labels_schema_rev,
140141
)
141142
)
143+
144+
def get_model_files_path(self, project_id: UUID, model_id: UUID) -> Path:
145+
"""
146+
Get the directory path containing the model files (model.xml and model.bin).
147+
148+
Args:
149+
project_id (UUID): The unique identifier of the project.
150+
model_id (UUID): The unique identifier of the model.
151+
152+
Returns:
153+
Path: The directory path containing the model files.
154+
155+
Raises:
156+
ResourceNotFoundError: If the model directory doesn't exist or required files are missing.
157+
FileNotFoundError: If the directories or model files are not found in the expected location.
158+
"""
159+
model_revision = self.get_model(project_id=project_id, model_id=model_id)
160+
if model_revision.files_deleted:
161+
raise ResourceNotFoundError(ResourceType.MODEL, str(model_id))
162+
163+
model_dir = self._projects_dir / str(project_id) / "models" / str(model_id)
164+
if not model_dir.exists():
165+
logger.error("Model directory not found: {}", model_dir)
166+
raise FileNotFoundError
167+
168+
xml_file = model_dir / "model.xml"
169+
bin_file = model_dir / "model.bin"
170+
if not xml_file.exists() or not bin_file.exists():
171+
logger.error("Model files missing in directory: {}", model_dir)
172+
raise FileNotFoundError
173+
174+
return model_dir

application/backend/tests/unit/routers/test_models.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,51 @@ def test_delete_model_in_use(self, fxt_get_project, fxt_model_service, fxt_clien
110110
assert response.status_code == status.HTTP_409_CONFLICT
111111
assert str(err) == response.json()["detail"]
112112
fxt_model_service.delete_model.assert_called_once_with(project_id=fxt_get_project.id, model_id=model_id)
113+
114+
def test_download_model_binary_success(self, fxt_get_project, fxt_model, fxt_model_service, fxt_client, tmp_path):
115+
import zipfile
116+
from io import BytesIO
117+
118+
# Create mock model files
119+
model_dir = tmp_path / "models" / str(fxt_model.id)
120+
model_dir.mkdir(parents=True)
121+
xml_content = "<xml>model data</xml>"
122+
bin_content = b"binary model data"
123+
(model_dir / "model.xml").write_text(xml_content)
124+
(model_dir / "model.bin").write_bytes(bin_content)
125+
126+
fxt_model_service.get_model_files_path.return_value = model_dir
127+
128+
response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/{fxt_model.id}/binary")
129+
130+
assert response.status_code == status.HTTP_200_OK
131+
assert response.headers["content-type"] == "application/zip"
132+
assert "content-disposition" in response.headers
133+
assert f"model-{fxt_model.id}-fp16.zip" in response.headers["content-disposition"]
134+
135+
# Verify zip file contents
136+
zip_data = BytesIO(response.content)
137+
with zipfile.ZipFile(zip_data, "r") as zip_file:
138+
assert "model.xml" in zip_file.namelist()
139+
assert "model.bin" in zip_file.namelist()
140+
assert zip_file.read("model.xml").decode() == xml_content
141+
assert zip_file.read("model.bin") == bin_content
142+
143+
fxt_model_service.get_model_files_path.assert_called_once_with(
144+
project_id=fxt_get_project.id, model_id=fxt_model.id
145+
)
146+
147+
def test_download_model_binary_not_found(self, fxt_get_project, fxt_model_service, fxt_client):
148+
model_id = uuid4()
149+
fxt_model_service.get_model_files_path.side_effect = ResourceNotFoundError(ResourceType.MODEL, str(model_id))
150+
151+
response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/{model_id}/binary")
152+
153+
assert response.status_code == status.HTTP_404_NOT_FOUND
154+
fxt_model_service.get_model_files_path.assert_called_once_with(project_id=fxt_get_project.id, model_id=model_id)
155+
156+
def test_download_model_binary_invalid_id(self, fxt_get_project, fxt_model_service, fxt_client):
157+
response = fxt_client.get(f"/api/projects/{fxt_get_project.id}/models/invalid-id/binary")
158+
159+
assert response.status_code == status.HTTP_400_BAD_REQUEST
160+
fxt_model_service.get_model.assert_not_called()

0 commit comments

Comments
 (0)