diff --git a/application/backend/app/alembic/versions/2786b50eb5a4_schema.py b/application/backend/app/alembic/versions/2786b50eb5a4_schema.py index 0186f58610..7d5bb35a1d 100644 --- a/application/backend/app/alembic/versions/2786b50eb5a4_schema.py +++ b/application/backend/app/alembic/versions/2786b50eb5a4_schema.py @@ -137,6 +137,7 @@ def upgrade() -> None: sa.Column("model_revision_id", sa.Text(), nullable=True), sa.Column("is_running", sa.Boolean(), nullable=False), sa.Column("data_collection_policies", sa.JSON(), nullable=False), + sa.Column("device", sa.String(length=50), nullable=False, server_default="cpu"), sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), sa.ForeignKeyConstraint(["model_revision_id"], ["model_revisions.id"], ondelete="RESTRICT"), diff --git a/application/backend/app/api/routers/pipelines.py b/application/backend/app/api/routers/pipelines.py index e998e76e24..51e51035a4 100644 --- a/application/backend/app/api/routers/pipelines.py +++ b/application/backend/app/api/routers/pipelines.py @@ -10,12 +10,12 @@ from fastapi.openapi.models import Example from pydantic import ValidationError -from app.api.dependencies import get_pipeline_metrics_service, get_pipeline_service +from app.api.dependencies import get_pipeline_metrics_service, get_pipeline_service, get_system_service from app.api.schemas import PipelineView from app.api.validators import ProjectID from app.models import DataCollectionPolicyAdapter, PipelineStatus from app.schemas.metrics import PipelineMetrics -from app.services import PipelineMetricsService, PipelineService, ResourceNotFoundError +from app.services import PipelineMetricsService, PipelineService, ResourceNotFoundError, SystemService router = APIRouter(prefix="/api/projects/{project_id}/pipeline", tags=["Pipelines"]) @@ -58,6 +58,11 @@ description="Remove all data 
collection policies of the pipeline", value={"data_collection_policies": []}, ), + "change_device": Example( + summary="Change inference device", + description="Change the device used for model inference (e.g., 'cpu', 'xpu', 'cuda', 'xpu-2', 'cuda-1')", + value={"device": "xpu"}, + ), } @@ -102,10 +107,19 @@ def update_pipeline( ), ], pipeline_service: Annotated[PipelineService, Depends(get_pipeline_service)], + system_service: Annotated[SystemService, Depends(get_system_service)], ) -> PipelineView: """Reconfigure an existing pipeline""" if "status" in pipeline_config: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The 'status' field cannot be changed") + + if "device" in pipeline_config: + device_str = pipeline_config["device"] + if not system_service.validate_device(device_str): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail=f"Device '{device_str}' is not available on this system" + ) + try: if "data_collection_policies" in pipeline_config: pipeline_config["data_collection_policies"] = [ diff --git a/application/backend/app/api/schemas/pipeline.py b/application/backend/app/api/schemas/pipeline.py index 83b39b428b..5dff9d121e 100644 --- a/application/backend/app/api/schemas/pipeline.py +++ b/application/backend/app/api/schemas/pipeline.py @@ -18,6 +18,7 @@ class PipelineView(BaseModel): model_revision: ModelRevision | None = Field(default=None, serialization_alias="model") status: PipelineStatus = PipelineStatus.IDLE data_collection_policies: list[DataCollectionPolicy] = Field(default_factory=list) + device: str = Field(default="cpu", description="Inference device (e.g., 'cpu', 'xpu', 'cuda', 'xpu-2', 'cuda-1')") model_config = { "json_schema_extra": { @@ -52,6 +53,7 @@ class PipelineView(BaseModel): "files_deleted": False, }, "status": "running", + "device": "cpu", "data_collection_policies": [ { "type": "fixed_rate", diff --git a/application/backend/app/db/schema.py b/application/backend/app/db/schema.py index 
6c84aa2dc9..d6e0f806de 100644 --- a/application/backend/app/db/schema.py +++ b/application/backend/app/db/schema.py @@ -48,6 +48,7 @@ class PipelineDB(Base): model_revision_id: Mapped[str | None] = mapped_column(Text, ForeignKey("model_revisions.id", ondelete="RESTRICT")) is_running: Mapped[bool] = mapped_column(Boolean, default=False) data_collection_policies: Mapped[list] = mapped_column(JSON, nullable=False, default=list) + device: Mapped[str] = mapped_column(String(50), nullable=False, default="cpu") sink = relationship("SinkDB", uselist=False, lazy="joined") source = relationship("SourceDB", uselist=False, lazy="joined") diff --git a/application/backend/app/models/pipeline.py b/application/backend/app/models/pipeline.py index adc65e47f7..5d2cddc368 100644 --- a/application/backend/app/models/pipeline.py +++ b/application/backend/app/models/pipeline.py @@ -58,6 +58,7 @@ class Pipeline(BaseEntity): model_revision_id: UUID | None = None status: PipelineStatus = PipelineStatus.IDLE data_collection_policies: list[DataCollectionPolicy] = Field(default_factory=list) + device: str = Field(default="cpu", pattern=r"^(cpu|xpu|cuda)(-\d+)?$") @model_validator(mode="before") def set_status_from_is_running(cls, data: Any) -> Any: diff --git a/application/backend/app/services/event/event_bus.py b/application/backend/app/services/event/event_bus.py index 1e758ba16f..2983ea85c2 100644 --- a/application/backend/app/services/event/event_bus.py +++ b/application/backend/app/services/event/event_bus.py @@ -12,6 +12,7 @@ class EventType(StrEnum): SINK_CHANGED = "SINK_CHANGED" PIPELINE_DATASET_COLLECTION_POLICIES_CHANGED = "PIPELINE_DATASET_COLLECTION_POLICIES_CHANGED" PIPELINE_STATUS_CHANGED = "PIPELINE_STATUS_CHANGED" + INFERENCE_DEVICE_CHANGED = "INFERENCE_DEVICE_CHANGED" class EventBus: diff --git a/application/backend/app/services/pipeline_service.py b/application/backend/app/services/pipeline_service.py index f4f5a4d4f6..7c5f4257e0 100644 --- 
def validate_device(self, device_str: str) -> bool:
    """
    Validate if a device string is available on the system.

    Args:
        device_str: Device string in format '<type>[-<index>]'
            (e.g., 'cpu', 'xpu', 'cuda', 'xpu-2', 'cuda-1'). A missing index is treated as 0.

    Returns:
        bool: True if the device is available, False otherwise

    Raises:
        ValueError: If device_str is not a well-formed device string
    """
    device_type, device_index = self._parse_device(device_str)

    # CPU is always available; the index is not checked for it
    if device_type == "cpu":
        return True

    # Check if desired device is among available devices.
    # A reported device without an index is compared as index 0.
    return any(
        device_type == available_device.type and device_index == (available_device.index or 0)
        for available_device in self.get_devices()
    )

@staticmethod
def _parse_device(device_str: str) -> tuple[str, int]:
    """
    Parse device string into type and index

    Args:
        device_str: Device string in format '<type>[-<index>]'
            (e.g., 'cpu', 'xpu', 'cuda', 'xpu-2', 'cuda-1'). Matching is case-insensitive.

    Returns:
        tuple[str, int]: Lower-cased device type and index (0 when no index is given)

    Raises:
        ValueError: If device_str is not a string or does not match the device pattern
    """
    # Callers may forward raw JSON values; report non-strings with the same error type
    # as malformed strings instead of failing on .lower()
    if not isinstance(device_str, str):
        raise ValueError(f"Invalid device string: {device_str}")

    # fullmatch() rather than match(): with match() the pattern's trailing '$' also matches
    # just before a final newline, so a value like 'cuda-1\n' would be accepted
    m = DEVICE_PATTERN.fullmatch(device_str.lower())
    if not m:
        raise ValueError(f"Invalid device string: {device_str}")

    # Groups are (type, '-<index>' or None, '<index>' or None); the input is already lower-cased
    device_type, _, raw_index = m.groups()
    return device_type, int(raw_index) if raw_index is not None else 0
def test_validate_device_xpu_not_available(self, fxt_system_service: SystemService):
    """XPU device strings are rejected when torch reports no XPU and no CUDA"""
    with patch("app.services.system_service.torch") as torch_stub:
        torch_stub.xpu.is_available.return_value = False
        torch_stub.cuda.is_available.return_value = False

        for device_str in ("xpu", "xpu-0"):
            assert fxt_system_service.validate_device(device_str) is False

def test_validate_device_cuda_available(self, fxt_system_service: SystemService):
    """CUDA device strings are accepted only for indices below the mocked device count of 3"""
    cuda_props = MagicMock()
    cuda_props.name = "NVIDIA GPU"
    cuda_props.total_memory = 25769803776

    # Device string -> expected validation result, checked in this order
    expectations = {
        "cuda": True,
        "cuda-0": True,
        "cuda-1": True,
        "cuda-2": True,
        "cuda-3": False,
    }

    with patch("app.services.system_service.torch") as torch_stub:
        torch_stub.xpu.is_available.return_value = False
        torch_stub.cuda.is_available.return_value = True
        torch_stub.cuda.device_count.return_value = 3
        torch_stub.cuda.get_device_properties.return_value = cuda_props

        for device_str, expected in expectations.items():
            assert fxt_system_service.validate_device(device_str) is expected
def test_validate_device_invalid_type(self, fxt_system_service: SystemService):
    """Test that every malformed device string is rejected with a ValueError"""
    invalid_devices = (
        "cpu-cpu",
        "cpu--1",
        "cpu-",
        "cpu-0.9",
        "1",
        "-1",
        "gpu",
        "tpu",
        "invalid",
    )

    with patch("app.services.system_service.torch") as mock_torch:
        mock_torch.xpu.is_available.return_value = False
        mock_torch.cuda.is_available.return_value = False

        # One pytest.raises per input: a single shared context exits at the first raised
        # exception, which would leave all remaining inputs untested
        for device_str in invalid_devices:
            with pytest.raises(ValueError):
                fxt_system_service.validate_device(device_str)