Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def upgrade() -> None:
sa.Column("model_revision_id", sa.Text(), nullable=True),
sa.Column("is_running", sa.Boolean(), nullable=False),
sa.Column("data_collection_policies", sa.JSON(), nullable=False),
sa.Column("device", sa.String(length=50), nullable=False, server_default="cpu"),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False),
sa.ForeignKeyConstraint(["model_revision_id"], ["model_revisions.id"], ondelete="RESTRICT"),
Expand Down
18 changes: 16 additions & 2 deletions application/backend/app/api/routers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from fastapi.openapi.models import Example
from pydantic import ValidationError

from app.api.dependencies import get_pipeline_metrics_service, get_pipeline_service
from app.api.dependencies import get_pipeline_metrics_service, get_pipeline_service, get_system_service
from app.api.schemas import PipelineView
from app.api.validators import ProjectID
from app.models import DataCollectionPolicyAdapter, PipelineStatus
from app.schemas.metrics import PipelineMetrics
from app.services import PipelineMetricsService, PipelineService, ResourceNotFoundError
from app.services import PipelineMetricsService, PipelineService, ResourceNotFoundError, SystemService

router = APIRouter(prefix="/api/projects/{project_id}/pipeline", tags=["Pipelines"])

Expand Down Expand Up @@ -58,6 +58,11 @@
description="Remove all data collection policies of the pipeline",
value={"data_collection_policies": []},
),
"change_device": Example(
summary="Change inference device",
description="Change the device used for model inference (e.g., 'cpu', 'xpu', 'cuda', 'xpu-2', 'cuda-1')",
value={"device": "xpu"},
),
}


Expand Down Expand Up @@ -102,10 +107,19 @@ def update_pipeline(
),
],
pipeline_service: Annotated[PipelineService, Depends(get_pipeline_service)],
system_service: Annotated[SystemService, Depends(get_system_service)],
) -> PipelineView:
"""Reconfigure an existing pipeline"""
if "status" in pipeline_config:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The 'status' field cannot be changed")

if "device" in pipeline_config:
device_str = pipeline_config["device"]
if not system_service.validate_device(device_str):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail=f"Device '{device_str}' is not available on this system"
)

try:
if "data_collection_policies" in pipeline_config:
pipeline_config["data_collection_policies"] = [
Expand Down
2 changes: 2 additions & 0 deletions application/backend/app/api/schemas/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PipelineView(BaseModel):
model_revision: ModelRevision | None = Field(default=None, serialization_alias="model")
status: PipelineStatus = PipelineStatus.IDLE
data_collection_policies: list[DataCollectionPolicy] = Field(default_factory=list)
device: str = Field(default="cpu", description="Inference device (e.g., 'cpu', 'xpu', 'cuda', 'xpu-2', 'cuda-1')")

model_config = {
"json_schema_extra": {
Expand Down Expand Up @@ -52,6 +53,7 @@ class PipelineView(BaseModel):
"files_deleted": False,
},
"status": "running",
"device": "cpu",
"data_collection_policies": [
{
"type": "fixed_rate",
Expand Down
1 change: 1 addition & 0 deletions application/backend/app/db/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class PipelineDB(Base):
model_revision_id: Mapped[str | None] = mapped_column(Text, ForeignKey("model_revisions.id", ondelete="RESTRICT"))
is_running: Mapped[bool] = mapped_column(Boolean, default=False)
data_collection_policies: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
device: Mapped[str] = mapped_column(String(50), nullable=False, default="cpu")

sink = relationship("SinkDB", uselist=False, lazy="joined")
source = relationship("SourceDB", uselist=False, lazy="joined")
Expand Down
1 change: 1 addition & 0 deletions application/backend/app/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class Pipeline(BaseEntity):
model_revision_id: UUID | None = None
status: PipelineStatus = PipelineStatus.IDLE
data_collection_policies: list[DataCollectionPolicy] = Field(default_factory=list)
device: str = Field(default="cpu", pattern=r"^(cpu|xpu|cuda)(-\d+)?$")

@model_validator(mode="before")
def set_status_from_is_running(cls, data: Any) -> Any:
Expand Down
1 change: 1 addition & 0 deletions application/backend/app/services/event/event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class EventType(StrEnum):
SINK_CHANGED = "SINK_CHANGED"
PIPELINE_DATASET_COLLECTION_POLICIES_CHANGED = "PIPELINE_DATASET_COLLECTION_POLICIES_CHANGED"
PIPELINE_STATUS_CHANGED = "PIPELINE_STATUS_CHANGED"
DEVICE_CHANGED = "DEVICE_CHANGED"


class EventBus:
Expand Down
3 changes: 3 additions & 0 deletions application/backend/app/services/pipeline_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def update_pipeline(self, project_id: UUID, partial_config: dict) -> Pipeline:
model_revision_id=str(to_update.model_revision_id) if to_update.model_revision_id else None,
is_running=to_update.status.as_bool,
data_collection_policies=[obj.model_dump() for obj in to_update.data_collection_policies],
device=to_update.device,
)
pipeline_db = pipeline_repo.update(to_update_db)
updated = Pipeline.model_validate(pipeline_db)
Expand All @@ -71,6 +72,8 @@ def update_pipeline(self, project_id: UUID, partial_config: dict) -> Pipeline:
self._event_bus.emit_event(EventType.SINK_CHANGED)
if pipeline.data_collection_policies != updated.data_collection_policies:
self._event_bus.emit_event(EventType.PIPELINE_DATASET_COLLECTION_POLICIES_CHANGED)
if pipeline.device != updated.device:
self._event_bus.emit_event(EventType.DEVICE_CHANGED)
elif pipeline.status != updated.status:
# If the pipeline is being activated or stopped
self._event_bus.emit_event(EventType.PIPELINE_STATUS_CHANGED)
Expand Down
27 changes: 27 additions & 0 deletions application/backend/app/services/system_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,30 @@ def get_devices() -> list[DeviceInfo]:
)

return devices

def validate_device(self, device_str: str) -> bool:
"""
Validate if a device string is available on the system.

Args:
device_str: Device string in format '<target>[-<index>]' (e.g., 'cpu', 'xpu', 'cuda', 'xpu-2', 'cuda-1')

Returns:
bool: True if the device is available, False otherwise
"""
# Parse device string
parts = device_str.split("-")
device_type = parts[0].lower()
device_index = int(parts[1]) if len(parts) > 1 else 0

# CPU is always available
if device_type == "cpu":
return True

# Check if desired device is among available devices
available_devices = self.get_devices()
for available_device in available_devices:
if device_type == available_device.type and device_index == (available_device.index or 0):
return True

return False
67 changes: 67 additions & 0 deletions application/backend/tests/unit/services/test_system_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,70 @@ def test_get_devices_with_multiple_devices(self, fxt_system_service: SystemServi
devices = fxt_system_service.get_devices()

assert len(devices) == 3

def test_validate_device_cpu_always_valid(self, fxt_system_service: SystemService):
"""Test that CPU device is always valid"""
assert fxt_system_service.validate_device("cpu") is True

def test_validate_device_xpu_available(self, fxt_system_service: SystemService):
"""Test validating XPU device when available"""
mock_xpu_dp = MagicMock()
mock_xpu_dp.name = "Intel XPU"
mock_xpu_dp.total_memory = 36022263808

with patch("app.services.system_service.torch") as mock_torch:
mock_torch.xpu.is_available.return_value = True
mock_torch.cuda.is_available.return_value = False
mock_torch.xpu.device_count.return_value = 2
mock_torch.xpu.get_device_properties.return_value = mock_xpu_dp

assert fxt_system_service.validate_device("xpu") is True
assert fxt_system_service.validate_device("xpu-0") is True
assert fxt_system_service.validate_device("xpu-1") is True
assert fxt_system_service.validate_device("xpu-2") is False

def test_validate_device_xpu_not_available(self, fxt_system_service: SystemService):
"""Test validating XPU device when not available"""
with patch("app.services.system_service.torch") as mock_torch:
mock_torch.xpu.is_available.return_value = False
mock_torch.cuda.is_available.return_value = False

assert fxt_system_service.validate_device("xpu") is False
assert fxt_system_service.validate_device("xpu-0") is False

def test_validate_device_cuda_available(self, fxt_system_service: SystemService):
"""Test validating CUDA device when available"""
mock_cuda_dp = MagicMock()
mock_cuda_dp.name = "NVIDIA GPU"
mock_cuda_dp.total_memory = 25769803776

with patch("app.services.system_service.torch") as mock_torch:
mock_torch.xpu.is_available.return_value = False
mock_torch.cuda.is_available.return_value = True
mock_torch.cuda.device_count.return_value = 3
mock_torch.cuda.get_device_properties.return_value = mock_cuda_dp

assert fxt_system_service.validate_device("cuda") is True
assert fxt_system_service.validate_device("cuda-0") is True
assert fxt_system_service.validate_device("cuda-1") is True
assert fxt_system_service.validate_device("cuda-2") is True
assert fxt_system_service.validate_device("cuda-3") is False

def test_validate_device_cuda_not_available(self, fxt_system_service: SystemService):
"""Test validating CUDA device when not available"""
with patch("app.services.system_service.torch") as mock_torch:
mock_torch.xpu.is_available.return_value = False
mock_torch.cuda.is_available.return_value = False

assert fxt_system_service.validate_device("cuda") is False
assert fxt_system_service.validate_device("cuda-0") is False

def test_validate_device_invalid_type(self, fxt_system_service: SystemService):
"""Test validating invalid device types"""
with patch("app.services.system_service.torch") as mock_torch:
mock_torch.xpu.is_available.return_value = False
mock_torch.cuda.is_available.return_value = False

assert fxt_system_service.validate_device("gpu") is False
assert fxt_system_service.validate_device("tpu") is False
assert fxt_system_service.validate_device("invalid") is False
Loading