diff --git a/src/flyte/_internal/runtime/resources_serde.py b/src/flyte/_internal/runtime/resources_serde.py
index a4c224936..cb9842849 100644
--- a/src/flyte/_internal/runtime/resources_serde.py
+++ b/src/flyte/_internal/runtime/resources_serde.py
@@ -24,6 +24,9 @@
     "V6E": "tpu-v6e-slice",
 }
 
+# Default prefix for GPU partition (MIG) resources
+DEFAULT_GPU_PARTITION_RESOURCE_PREFIX = "nvidia.com/mig"
+
 _DeviceClassToProto: Dict[DeviceClass, "tasks_pb2.GPUAccelerator.DeviceClass"] = {
     "GPU": tasks_pb2.GPUAccelerator.NVIDIA_GPU,
     "TPU": tasks_pb2.GPUAccelerator.GOOGLE_TPU,
@@ -85,8 +88,10 @@ def _get_disk_resource_entry(disk: str) -> tasks_pb2.Resources.ResourceEntry:
 
 def get_proto_extended_resources(resources: Resources | None) -> Optional[tasks_pb2.ExtendedResources]:
     """
-    TODO Implement partitioning logic string handling for GPU
-    :param resources:
+    Get the extended resources (GPU accelerator, shared memory) for the task.
+
+    :param resources: Resources object containing the GPU and shared memory configuration
+    :return: ExtendedResources protobuf, or None if no extended resources are configured
     """
     if resources is None:
         return None
@@ -128,7 +133,10 @@ def _convert_resources_to_resource_entries(
     if resources.gpu is not None:
         device = resources.get_device()
         if device is not None:
-            request_entries.append(_get_gpu_resource_entry(device.quantity))
+            if device.partition is None:
+                # Only add the standard GPU resource when no partition is requested.
+                # Partitioned (MIG) GPUs are handled separately at Pod spec creation.
+                request_entries.append(_get_gpu_resource_entry(device.quantity))
     if resources.disk is not None:
         request_entries.append(_get_disk_resource_entry(resources.disk))
 
diff --git a/src/flyte/_internal/runtime/task_serde.py b/src/flyte/_internal/runtime/task_serde.py
index 1a7b3a588..d5852ae17 100644
--- a/src/flyte/_internal/runtime/task_serde.py
+++ b/src/flyte/_internal/runtime/task_serde.py
@@ -6,7 +6,7 @@
 import copy
 import typing
 from datetime import timedelta
-from typing import Optional, cast
+from typing import Dict, Optional, cast
 
 from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
 from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
@@ -23,7 +23,7 @@
 from ... import ReusePolicy
 from ..._retry import RetryStrategy
 from ..._timeout import TimeoutType, timeout_from_request
-from .resources_serde import get_proto_extended_resources, get_proto_resources
+from .resources_serde import DEFAULT_GPU_PARTITION_RESOURCE_PREFIX, get_proto_extended_resources, get_proto_resources
 from .reuse import add_reusable
 from .types_serde import transform_native_to_typed_interface
 
@@ -124,7 +124,8 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
 
     extra_config: typing.Dict[str, str] = {}
     if task.pod_template and not isinstance(task.pod_template, str):
-        pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
+        extended_resources = get_proto_extended_resources(task.resources)
+        pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template, task, extended_resources)
         extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
         container = None
     else:
@@ -264,10 +265,61 @@ def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
     return tasks_pb2.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")
 
 
-def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTemplate) -> Optional[tasks_pb2.K8sPod]:
+def _get_mig_resources_from_extended_resources(
+    extended_resources: Optional[tasks_pb2.ExtendedResources],
+    device_quantity: Optional[int] = None,
+    mig_resource_prefix: Optional[str] = None,
+) -> Dict[str, str]:
+    """
+    Generate MIG-specific resources from GPUAccelerator partition info.
+
+    When a GPU has a partition_size specified, generate a custom resource name
+    for that partition instead of using the generic GPU resource. This resource
+    will be added to both requests and limits in the Pod spec.
+
+    Example:
+        If partition="1g.10gb" and prefix="nvidia.com/mig", returns:
+        {"nvidia.com/mig-1g.10gb": "1"}
+
+    :param extended_resources: The extended resources containing GPU accelerator info
+    :param device_quantity: The quantity of GPUs/partitions requested
+    :param mig_resource_prefix: Custom MIG resource prefix (defaults to "nvidia.com/mig").
+        Can be overridden via Resources.gpu_partition_resource_prefix
+    :return: Dict mapping MIG resource name to quantity (e.g., {"nvidia.com/mig-1g.10gb": "1"})
+    """
+    mig_resources: Dict[str, str] = {}
+
+    if extended_resources is None or not extended_resources.gpu_accelerator:
+        return mig_resources
+
+    gpu_accel = extended_resources.gpu_accelerator
+    partition = gpu_accel.partition_size
+
+    if not partition:
+        return mig_resources
+
+    # Default to "nvidia.com/mig" if not specified
+    prefix = mig_resource_prefix if mig_resource_prefix is not None else DEFAULT_GPU_PARTITION_RESOURCE_PREFIX
+    resource_name = f"{prefix}-{partition}"
+
+    quantity = device_quantity if device_quantity is not None else 1
+    mig_resources[resource_name] = str(quantity)
+
+    return mig_resources
+
+
+def _get_k8s_pod(
+    primary_container: tasks_pb2.Container,
+    pod_template: PodTemplate,
+    task_template: TaskTemplate,
+    extended_resources: Optional[tasks_pb2.ExtendedResources],
+) -> Optional[tasks_pb2.K8sPod]:
     """
     Get the K8sPod representation of the task template.
-    :param task: The task to convert.
+    :param primary_container: The primary container to use.
+    :param pod_template: The pod template to use.
+    :param task_template: The task template containing the resources configuration.
+    :param extended_resources: The extended resources (GPU accelerator, shared memory).
     :return: The K8sPod representation of the task template.
     """
     from kubernetes.client import ApiClient, V1PodSpec
@@ -307,8 +359,18 @@ def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTempla
 
         for resource in primary_container.resources.requests:
             requests[_sanitize_resource_name(resource)] = resource.value
+        # Add MIG resources if a GPU partition is specified
+        mig_prefix = task_template.resources.gpu_partition_resource_prefix if task_template.resources else None
+        # Get the device quantity from resources
+        device = task_template.resources.get_device() if task_template.resources else None
+        device_quantity = device.quantity if device else None
+        mig_resources = _get_mig_resources_from_extended_resources(extended_resources, device_quantity, mig_prefix)
+        # Add MIG resources to both requests and limits
+        requests.update(mig_resources)
+        limits.update(mig_resources)
+
         resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
-        if len(limits) > 0 or len(requests) > 0:
+        if len(limits) > 0 or len(requests) > 0 or mig_resources:
             # Important! Only copy over resource requirements if they are non-empty.
             container.resources = resource_requirements
 
diff --git a/src/flyte/_resources.py b/src/flyte/_resources.py
index 8e18f6227..3396fccff 100644
--- a/src/flyte/_resources.py
+++ b/src/flyte/_resources.py
@@ -366,6 +366,9 @@ def my_task() -> int:
     :param disk: The amount of disk to allocate to the task. This is a string of the form "10GiB".
     :param shm: The amount of shared memory to allocate to the task. This is a string of the form "10GiB" or "auto".
         If "auto", then the shared memory will be set to max amount of shared memory available on the node.
+    :param gpu_partition_resource_prefix: Optional override for the GPU partition (MIG) resource name prefix.
+        Defaults to "nvidia.com/mig" and only applies when a GPU partition is requested. For example, setting it to
+        "custom.company.com/mig" yields resource names like "custom.company.com/mig-1g.10gb".
     """
 
     cpu: Union[CPUBaseType, Tuple[CPUBaseType, CPUBaseType], None] = None
@@ -373,6 +376,7 @@ def my_task() -> int:
     gpu: Union[Accelerators, int, Device, None] = None
     disk: Union[str, None] = None
     shm: Union[str, Literal["auto"], None] = None
+    gpu_partition_resource_prefix: Optional[str] = None
 
     def __post_init__(self):
         if isinstance(self.cpu, tuple):
diff --git a/src/flyte/app/_runtime/app_serde.py b/src/flyte/app/_runtime/app_serde.py
index 51dc8f2db..340dface6 100644
--- a/src/flyte/app/_runtime/app_serde.py
+++ b/src/flyte/app/_runtime/app_serde.py
@@ -9,7 +9,7 @@
 
 from copy import deepcopy
 from dataclasses import replace
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from flyteidl2.app import app_definition_pb2
 from flyteidl2.common import runtime_version_pb2
@@ -83,6 +83,43 @@ def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
     return tasks_pb2.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")
 
 
+def _get_mig_resources_from_extended_resources(
+    extended_resources: Optional[tasks_pb2.ExtendedResources],
+    device_quantity: Optional[int] = None,
+    mig_resource_prefix: Optional[str] = None,
+) -> Dict[str, str]:
+    """
+    Generate MIG-specific resources from GPUAccelerator partition info.
+
+    When a GPU has a partition_size specified, generate a custom resource name
+    for that partition instead of using the generic GPU resource.
+
+    :param extended_resources: The extended resources containing GPU accelerator info
+    :param device_quantity: The quantity of GPUs/partitions requested
+    :param mig_resource_prefix: Custom MIG resource prefix (defaults to "nvidia.com/mig")
+    :return: Dict mapping MIG resource name to quantity (e.g., {"nvidia.com/mig-1g.5gb": "1"})
+    """
+    mig_resources: Dict[str, str] = {}
+
+    if extended_resources is None or not extended_resources.gpu_accelerator:
+        return mig_resources
+
+    gpu_accel = extended_resources.gpu_accelerator
+    partition = gpu_accel.partition_size
+
+    if not partition:
+        return mig_resources
+
+    # Default to "nvidia.com/mig" if not specified
+    prefix = mig_resource_prefix if mig_resource_prefix is not None else "nvidia.com/mig"
+    resource_name = f"{prefix}-{partition}"
+
+    quantity = device_quantity if device_quantity is not None else 1
+    mig_resources[resource_name] = str(quantity)
+
+    return mig_resources
+
+
 def _serialized_pod_spec(
     app_env: AppEnvironment,
     pod_template: flyte.PodTemplate,
@@ -133,15 +170,27 @@ def _serialized_pod_spec(
 
         limits, requests = {}, {}
         resources = get_proto_resources(app_env.resources)
+        extended_resources = get_proto_extended_resources(app_env.resources)
 
         if resources:
            for resource in resources.limits:
                limits[_sanitize_resource_name(resource)] = resource.value
            for resource in resources.requests:
                requests[_sanitize_resource_name(resource)] = resource.value
 
+        # Add MIG resources if a GPU partition is specified
+        mig_prefix = app_env.resources.gpu_partition_resource_prefix if app_env.resources else None
+        # Get the device quantity from resources
+        device = app_env.resources.get_device() if app_env.resources else None
+        device_quantity = device.quantity if device else None
+        mig_resources = _get_mig_resources_from_extended_resources(
+            extended_resources, device_quantity, mig_prefix
+        )
+        requests.update(mig_resources)
+        limits.update(mig_resources)
+
         resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
-        if limits or requests:
+        if limits or requests or mig_resources:
             container.resources = resource_requirements
 
         if app_env.env_vars:
diff --git a/tests/flyte/internal/runtime/test_task_serde.py b/tests/flyte/internal/runtime/test_task_serde.py
index 3c71dfd13..e937f972c 100644
--- a/tests/flyte/internal/runtime/test_task_serde.py
+++ b/tests/flyte/internal/runtime/test_task_serde.py
@@ -16,6 +16,7 @@
 
 import flyte
 from flyte import PodTemplate
+from flyte._internal.runtime.resources_serde import get_proto_extended_resources
 from flyte._internal.runtime.task_serde import (
     _get_k8s_pod,
     _get_urun_container,
@@ -242,7 +243,8 @@ async def t1(a: int, b: str) -> str:
     assert isinstance(proto_task, tasks_pb2.TaskTemplate)
 
     # Check k8s_pod
-    k8s_pod = _get_k8s_pod(_get_urun_container(context, t1), pod_template1)
+    extended_resources = get_proto_extended_resources(t1.resources)
+    k8s_pod = _get_k8s_pod(_get_urun_container(context, t1), pod_template1, t1, extended_resources)
     assert proto_task.k8s_pod == k8s_pod
     assert proto_task.k8s_pod.metadata.labels == {"foo": "bar"}
     assert proto_task.k8s_pod.metadata.annotations == {"baz": "qux"}
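
Usage note (not part of the patch): the sketch below illustrates how the new helper in task_serde.py derives a MIG resource name from a GPUAccelerator with a partition_size, and how Resources.gpu_partition_resource_prefix overrides the default nvidia.com/mig prefix. The field names gpu_accelerator and partition_size are taken from the diff above; the assumption that flyteidl2 protos are importable and accept these keyword arguments is mine.

# Illustrative sketch only, not part of this patch.
from flyteidl2.core import tasks_pb2

from flyte._internal.runtime.task_serde import _get_mig_resources_from_extended_resources

# A GPU accelerator requesting a 1g.10gb MIG slice.
ext = tasks_pb2.ExtendedResources(
    gpu_accelerator=tasks_pb2.GPUAccelerator(partition_size="1g.10gb"),
)

# Default prefix -> {"nvidia.com/mig-1g.10gb": "2"}
print(_get_mig_resources_from_extended_resources(ext, device_quantity=2))

# Custom prefix, e.g. set via Resources(gpu_partition_resource_prefix="custom.company.com/mig")
# -> {"custom.company.com/mig-1g.10gb": "1"}
print(_get_mig_resources_from_extended_resources(ext, mig_resource_prefix="custom.company.com/mig"))

# No partition -> {} ; the generic GPU resource entry from resources_serde.py is used instead.
print(_get_mig_resources_from_extended_resources(tasks_pb2.ExtendedResources()))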