Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/flyte/_internal/runtime/resources_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
"V6E": "tpu-v6e-slice",
}

# Default prefix for GPU partition (MIG) resources
DEFAULT_GPU_PARTITION_RESOURCE_PREFIX = "nvidia.com/mig"

_DeviceClassToProto: Dict[DeviceClass, "tasks_pb2.GPUAccelerator.DeviceClass"] = {
"GPU": tasks_pb2.GPUAccelerator.NVIDIA_GPU,
"TPU": tasks_pb2.GPUAccelerator.GOOGLE_TPU,
Expand Down Expand Up @@ -85,8 +88,10 @@ def _get_disk_resource_entry(disk: str) -> tasks_pb2.Resources.ResourceEntry:

def get_proto_extended_resources(resources: Resources | None) -> Optional[tasks_pb2.ExtendedResources]:
"""
TODO Implement partitioning logic string handling for GPU
:param resources:
Get extended resources (GPU accelerator, shared memory) for the task.

:param resources: Resources object containing GPU and shared memory configuration
:return: ExtendedResources protobuf or None if no extended resources are configured
"""
if resources is None:
return None
Expand Down Expand Up @@ -128,7 +133,11 @@ def _convert_resources_to_resource_entries(
if resources.gpu is not None:
device = resources.get_device()
if device is not None:
request_entries.append(_get_gpu_resource_entry(device.quantity))
if device.partition is None:
# Only add standard GPU resource if NO partition
# Partitioned GPUs (MIG) are handled separately at Pod spec creation
if device.partition is None:
request_entries.append(_get_gpu_resource_entry(device.quantity))

if resources.disk is not None:
request_entries.append(_get_disk_resource_entry(resources.disk))
Expand Down
74 changes: 68 additions & 6 deletions src/flyte/_internal/runtime/task_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import copy
import typing
from datetime import timedelta
from typing import Optional, cast
from typing import Dict, Optional, cast

from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
Expand All @@ -23,7 +23,7 @@
from ... import ReusePolicy
from ..._retry import RetryStrategy
from ..._timeout import TimeoutType, timeout_from_request
from .resources_serde import get_proto_extended_resources, get_proto_resources
from .resources_serde import DEFAULT_GPU_PARTITION_RESOURCE_PREFIX, get_proto_extended_resources, get_proto_resources
from .reuse import add_reusable
from .types_serde import transform_native_to_typed_interface

Expand Down Expand Up @@ -124,7 +124,8 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
extra_config: typing.Dict[str, str] = {}

if task.pod_template and not isinstance(task.pod_template, str):
pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
extended_resources = get_proto_extended_resources(task.resources)
pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template, task, extended_resources)
extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
container = None
else:
Expand Down Expand Up @@ -264,10 +265,61 @@ def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
return tasks_pb2.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")


def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTemplate) -> Optional[tasks_pb2.K8sPod]:
def _get_mig_resources_from_extended_resources(
extended_resources: Optional[tasks_pb2.ExtendedResources],
device_quantity: Optional[int] = None,
mig_resource_prefix: Optional[str] = None,
) -> Dict[str, str]:
"""
Generate MIG-specific resources from GPUAccelerator partition info.

When a GPU has a partition_size specified, generate a custom resource name
for that partition instead of using the generic GPU resource. This resource
will be added to both requests and limits in the Pod spec.

Example:
If partition="1g.10gb" and prefix="nvidia.com/mig", returns:
{"nvidia.com/mig-1g.10gb": "1"}

:param extended_resources: The extended resources containing GPU accelerator info
:param device_quantity: The quantity of GPUs/partitions requested
:param mig_resource_prefix: Custom MIG resource prefix (defaults to "nvidia.com/mig").
Can be overridden via Resources.gpu_partition_resource_prefix
:return: Dict mapping MIG resource name to quantity (e.g., {"nvidia.com/mig-1g.10gb": "1"})
"""
mig_resources = Dict[str, str] = {}

if extended_resources is None or not extended_resources.gpu_accelerator:
return mig_resources

gpu_accel = extended_resources.gpu_accelerator
partition = gpu_accel.partition_size

if not partition:
return mig_resources

# Default to "nvidia.com/mig" if not specified
prefix = mig_resource_prefix if mig_resource_prefix is not None else DEFAULT_GPU_PARTITION_RESOURCE_PREFIX
resource_name = f"{prefix}-{partition}"

quantity = device_quantity if device_quantity is not None else 1
mig_resources[resource_name] = str(quantity)

return mig_resources


def _get_k8s_pod(
primary_container: tasks_pb2.Container,
pod_template: PodTemplate,
task_template: TaskTemplate,
extended_resources: Optional[tasks_pb2.ExtendedResources],
) -> Optional[tasks_pb2.K8sPod]:
"""
Get the K8sPod representation of the task template.
:param task: The task to convert.
:param primary_container: The primary container to use.
:param pod_template: The pod template to use.
:param task_template: The task template containing resources configuration.
:param extended_resources: The extended resources (GPU accelerator, shared memory).
:return: The K8sPod representation of the task template.
"""
from kubernetes.client import ApiClient, V1PodSpec
Expand Down Expand Up @@ -307,8 +359,18 @@ def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTempla
for resource in primary_container.resources.requests:
requests[_sanitize_resource_name(resource)] = resource.value

# Add MIG resources if GPU partitions are specified
mig_prefix = task_template.resources.gpu_partition_resource_prefix if task_template.resources else None
# Get device quantity from resources
device = task_template.resources.get_device() if task_template.resources else None
device_quantity = device.quantity if device else None
mig_resources = _get_mig_resources_from_extended_resources(extended_resources, device_quantity, mig_prefix)
# Add MIG resources to both requests and limits
requests.update(mig_resources)
limits.update(mig_resources)

resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
if len(limits) > 0 or len(requests) > 0:
if len(limits) > 0 or len(requests) > 0 or mig_resources:
# Important! Only copy over resource requirements if they are non-empty.
container.resources = resource_requirements

Expand Down
4 changes: 4 additions & 0 deletions src/flyte/_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,17 @@ def my_task() -> int:
:param disk: The amount of disk to allocate to the task. This is a string of the form "10GiB".
:param shm: The amount of shared memory to allocate to the task. This is a string of the form "10GiB" or "auto".
If "auto", then the shared memory will be set to max amount of shared memory available on the node.
:param gpu_partition_resource_prefix: Optional override for the GPU partition (MIG) resource name prefix.
Defaults to "nvidia.com/mig" when a GPU partition is specified. Only applies when GPU partition is requested.
For example, set to "custom.company.com/mig" to override the resource name prefix.
"""

cpu: Union[CPUBaseType, Tuple[CPUBaseType, CPUBaseType], None] = None
memory: Union[str, Tuple[str, str], None] = None
gpu: Union[Accelerators, int, Device, None] = None
disk: Union[str, None] = None
shm: Union[str, Literal["auto"], None] = None
gpu_partition_resource_prefix: Optional[str] = None

def __post_init__(self):
if isinstance(self.cpu, tuple):
Expand Down
53 changes: 51 additions & 2 deletions src/flyte/app/_runtime/app_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from copy import deepcopy
from dataclasses import replace
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

from flyteidl2.app import app_definition_pb2
from flyteidl2.common import runtime_version_pb2
Expand Down Expand Up @@ -83,6 +83,43 @@ def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
return tasks_pb2.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")


def _get_mig_resources_from_extended_resources(
extended_resources: Optional[tasks_pb2.ExtendedResources],
device_quantity: Optional[int] = None,
mig_resource_prefix: Optional[str] = None,
) -> Dict[str, str]:
"""
Generate MIG-specific resources from GPUAccelerator partition info.

When a GPU has a partition_size specified, generate a custom resource name
for that partition instead of using the generic GPU resource.

:param extended_resources: The extended resources containing GPU accelerator info
:param mig_resource_prefix: Custom MIG resource prefix (defaults to "nvidia.com/mig")
:param device_quantity: The quantity of GPUs/partitions requested
:return: Dict mapping MIG resource name to quantity (e.g., {"nvidia.com/mig-1g. 5gb": "1"})
"""
mig_resources: Dict[str, str] = {}

if extended_resources is None or not extended_resources.gpu_accelerator:
return mig_resources

gpu_accel = extended_resources.gpu_accelerator
partition = gpu_accel.partition_size

if not partition:
return mig_resources

# Default to "nvidia.com/mig" if not specified
prefix = mig_resource_prefix if mig_resource_prefix is not None else "nvidia.com/mig"
resource_name = f"{prefix}-{partition}"

quantity = device_quantity if device_quantity is not None else 1
mig_resources[resource_name] = str(quantity)

return mig_resources


def _serialized_pod_spec(
app_env: AppEnvironment,
pod_template: flyte.PodTemplate,
Expand Down Expand Up @@ -133,15 +170,27 @@ def _serialized_pod_spec(

limits, requests = {}, {}
resources = get_proto_resources(app_env.resources)
extended_resources = get_proto_extended_resources(app_env.resources)
if resources:
for resource in resources.limits:
limits[_sanitize_resource_name(resource)] = resource.value
for resource in resources.requests:
requests[_sanitize_resource_name(resource)] = resource.value

# Add MIG resources if GPU partitions are specified
mig_prefix = app_env.resources.gpu_partition_resource_prefix if app_env.resources else None
# Get device quantity from resources
device = app_env.resources.get_device() if app_env.resources else None
device_quantity = device.quantity if device else None
mig_resources = _get_mig_resources_from_extended_resources(
extended_resources, device_quantity, mig_prefix
)
requests.update(mig_resources)
limits.update(mig_resources)

resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)

if limits or requests:
if limits or requests or mig_resources:
container.resources = resource_requirements

if app_env.env_vars:
Expand Down
4 changes: 3 additions & 1 deletion tests/flyte/internal/runtime/test_task_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import flyte
from flyte import PodTemplate
from flyte._internal.runtime.resources_serde import get_proto_extended_resources
from flyte._internal.runtime.task_serde import (
_get_k8s_pod,
_get_urun_container,
Expand Down Expand Up @@ -242,7 +243,8 @@ async def t1(a: int, b: str) -> str:
assert isinstance(proto_task, tasks_pb2.TaskTemplate)

# Check k8s_pod
k8s_pod = _get_k8s_pod(_get_urun_container(context, t1), pod_template1)
extended_resources = get_proto_extended_resources(t1.resources)
k8s_pod = _get_k8s_pod(_get_urun_container(context, t1), pod_template1, t1, extended_resources)
assert proto_task.k8s_pod == k8s_pod
assert proto_task.k8s_pod.metadata.labels == {"foo": "bar"}
assert proto_task.k8s_pod.metadata.annotations == {"baz": "qux"}
Expand Down
Loading