Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from flytekit.configuration import ImageConfig
from flytekit.configuration.default_images import DefaultImages
from flytekit.constants import CopyFileDetection
from flytekit.interaction.click_types import key_value_callback
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.interaction.click_types import key_value_callback, resource_callback
from flytekit.loggers import logger
from flytekit.tools import repo

Expand Down Expand Up @@ -134,6 +135,22 @@
callback=key_value_callback,
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
)
@click.option(
"--resource-requests",
required=False,
type=str,
callback=resource_callback,
help="Override default task resource requests for tasks that have no statically defined resource requests in their task decorator. "
"""Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'""",
)
@click.option(
"--resource-limits",
required=False,
type=str,
callback=resource_callback,
help="Override default task resource limits for tasks that have no statically defined resource limits in their task decorator. "
"""Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'""",
)
@click.option(
"--skip-errors",
"--skip-error",
Expand Down Expand Up @@ -161,6 +178,8 @@ def register(
dry_run: bool,
activate_launchplans: bool,
env: typing.Optional[typing.Dict[str, str]],
resource_requests: typing.Optional[Resources],
resource_limits: typing.Optional[Resources],
skip_errors: bool,
):
"""
Expand Down Expand Up @@ -225,6 +244,9 @@ def register(
package_or_module=package_or_module,
remote=remote,
env=env,
default_resources=ResourceSpec(
requests=resource_requests or Resources(), limits=resource_limits or Resources()
),
dry_run=dry_run,
activate_launchplans=activate_launchplans,
skip_errors=skip_errors,
Expand Down
28 changes: 28 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from flytekit.core.artifact import ArtifactQuery
from flytekit.core.base_task import PythonTask
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.exceptions.system import FlyteSystemException
Expand All @@ -51,6 +52,7 @@
FlyteLiteralConverter,
key_value_callback,
labels_callback,
resource_callback,
)
from flytekit.interaction.string_literals import literal_string_repr
from flytekit.loggers import logger
Expand Down Expand Up @@ -208,6 +210,28 @@ class RunLevelParams(PyFlyteParams):
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
)
)
resource_requests: typing.Optional[Resources] = make_click_option_field(
click.Option(
param_decls=["--resource-requests"],
required=False,
show_default=True,
type=str,
callback=resource_callback,
help="This overrides default task resource requests for tasks that have no statically defined resource requests in their task decorator. "
"""Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'""",
)
)
resource_limits: typing.Optional[Resources] = make_click_option_field(
click.Option(
param_decls=["--resource-limits"],
required=False,
show_default=True,
type=str,
callback=resource_callback,
help="This overrides default task resource limits for tasks that have no statically defined resource limits in their task decorator. "
"""Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'""",
)
)
tags: typing.List[str] = make_click_option_field(
click.Option(
param_decls=["--tags", "--tag"],
Expand Down Expand Up @@ -756,6 +780,10 @@ def _run(*args, **kwargs):
source_path=run_level_params.computed_params.project_root,
module_name=run_level_params.computed_params.module,
fast_package_options=fast_package_options,
default_resources=ResourceSpec(
requests=run_level_params.resource_requests or Resources(),
limits=run_level_params.resource_limits or Resources(),
),
)

run_remote(
Expand Down
7 changes: 7 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
from flytekit.configuration import internal as _internal
from flytekit.configuration.default_images import DefaultImages
from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists
from flytekit.core.resources import ResourceSpec
from flytekit.image_spec import ImageSpec
from flytekit.image_spec.image_spec import ImageBuildEngine
from flytekit.loggers import logger
Expand Down Expand Up @@ -805,6 +806,8 @@ class SerializationSettings(DataClassJsonMixin):
version (str): The version (if any) with which to register entities under.
image_config (ImageConfig): The image config used to define task container images.
env (Optional[Dict[str, str]]): Environment variables injected into task container definitions.
default_resources (Optional[ResourceSpec]): The resources to request for the task - this is useful
if users need to override the default resource spec of an entity at registration time.
flytekit_virtualenv_root (Optional[str]): During out of container serialize the absolute path of the flytekit
virtualenv at serialization time won't match the in-container value at execution time. This optional value
is used to provide the in-container virtualenv path
Expand All @@ -823,6 +826,7 @@ class SerializationSettings(DataClassJsonMixin):
domain: typing.Optional[str] = None
version: typing.Optional[str] = None
env: Optional[Dict[str, str]] = None
default_resources: Optional[ResourceSpec] = None
git_repo: Optional[str] = None
python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER
flytekit_virtualenv_root: Optional[str] = None
Expand Down Expand Up @@ -897,6 +901,7 @@ def new_builder(self) -> Builder:
version=self.version,
image_config=self.image_config,
env=self.env.copy() if self.env else None,
default_resources=self.default_resources,
git_repo=self.git_repo,
flytekit_virtualenv_root=self.flytekit_virtualenv_root,
python_interpreter=self.python_interpreter,
Expand Down Expand Up @@ -948,6 +953,7 @@ class Builder(object):
version: str
image_config: ImageConfig
env: Optional[Dict[str, str]] = None
default_resources: Optional[ResourceSpec] = None
git_repo: Optional[str] = None
flytekit_virtualenv_root: Optional[str] = None
python_interpreter: Optional[str] = None
Expand All @@ -965,6 +971,7 @@ def build(self) -> SerializationSettings:
version=self.version,
image_config=self.image_config,
env=self.env,
default_resources=self.default_resources,
git_repo=self.git_repo,
flytekit_virtualenv_root=self.flytekit_virtualenv_root,
python_interpreter=self.python_interpreter,
Expand Down
11 changes: 11 additions & 0 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,17 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
if elem:
env.update(elem)

# Override the task's resource spec if it was not set statically in the task definition

def _resources_unspecified(resources: ResourceSpec) -> bool:
return resources == ResourceSpec(
requests=Resources(),
limits=Resources(),
)

if isinstance(settings.default_resources, ResourceSpec) and _resources_unspecified(self.resources):
self._resources = settings.default_resources

# Add runtime dependencies into environment
if isinstance(self.container_image, ImageSpec) and self.container_image.runtime_packages:
runtime_packages = " ".join(self.container_image.runtime_packages)
Expand Down
28 changes: 28 additions & 0 deletions flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flytekit import BlobType, FlyteContext, Literal, LiteralType, StructuredDataset
from flytekit.core.artifact import ArtifactQuery
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.resources import Resources
from flytekit.core.type_engine import TypeEngine
from flytekit.models.types import SimpleType
from flytekit.remote.remote_fs import FlytePathResolver
Expand Down Expand Up @@ -84,6 +85,33 @@ def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typi
return result


def resource_callback(_: typing.Any, param: str, value: typing.Optional[str]) -> typing.Optional[Resources]:
"""
Callback for click to parse a resource from a comma-separated string of the form 'cpu=1,mem=2Gi' for example
"""
if not value:
return None

items = value.split(",")
_allowed_keys = Resources.__annotations__.keys()
result = {}
for item in items:
kv_split = item.split("=")
if len(kv_split) != 2:
raise click.BadParameter(
f"Expected comma separated key-value pairs of the form 'key1=value1,key2=value2,...', got '{item}'"
)
k = kv_split[0].strip()
v = kv_split[1].strip()
if k not in _allowed_keys:
raise click.BadParameter(f"Expected key to be one of {list(_allowed_keys)}, but got '{k}'")
if k in result:
raise click.BadParameter(f"Expected unique keys {list(_allowed_keys)}, but got '{k}' multiple times")
result[k] = v

return Resources(**result)


class DirParamType(click.ParamType):
name = "directory path"

Expand Down
4 changes: 4 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
)
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec
from flytekit.core.resources import ResourceSpec
from flytekit.core.task import ReferenceTask
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import LiteralsResolver, TypeEngine, strict_type_hint_matching
Expand Down Expand Up @@ -1326,6 +1327,7 @@ def register_script(
source_path: typing.Optional[str] = None,
module_name: typing.Optional[str] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
default_resources: typing.Optional[ResourceSpec] = None,
fast_package_options: typing.Optional[FastPackageOptions] = None,
) -> typing.Union[FlyteWorkflow, FlyteTask, FlyteLaunchPlan, ReferenceEntity]:
"""
Expand All @@ -1342,6 +1344,7 @@ def register_script(
:param source_path: The root of the project path
:param module_name: the name of the module
:param envs: Environment variables to be passed to the serialization
:param default_resources: Default resources to be passed to the serialization. These override the resource spec for any tasks that have no statically defined resource requests and limits.
:param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False.
:return:
"""
Expand Down Expand Up @@ -1380,6 +1383,7 @@ def register_script(
image_config=image_config,
git_repo=_get_git_repo_url(source_path),
env=envs,
default_resources=default_resources,
fast_serialization_settings=FastSerializationSettings(
enabled=True,
destination_dir=destination_dir,
Expand Down
3 changes: 3 additions & 0 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from flytekit.constants import CopyFileDetection
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager, FlyteEntities
from flytekit.core.resources import ResourceSpec
from flytekit.loggers import logger
from flytekit.models import launch_plan, task
from flytekit.models.core.identifier import Identifier
Expand Down Expand Up @@ -252,6 +253,7 @@ def register(
remote: FlyteRemote,
copy_style: CopyFileDetection,
env: typing.Optional[typing.Dict[str, str]],
default_resources: typing.Optional[ResourceSpec],
dry_run: bool = False,
activate_launchplans: bool = False,
skip_errors: bool = False,
Expand All @@ -274,6 +276,7 @@ def register(
image_config=image_config,
fast_serialization_settings=None, # should probably add incomplete fast settings
env=env,
default_resources=default_resources,
)

if not version and copy_style == CopyFileDetection.NO_COPY:
Expand Down
98 changes: 98 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
import uuid
import pytest
from unittest import mock
import random
import string
from dataclasses import asdict, dataclass

from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow
from flytekit.configuration import Config, ImageConfig, SerializationSettings
from flytekit.core.launch_plan import reference_launch_plan
from flytekit.core.task import reference_task
from flytekit.core.workflow import reference_workflow
from flytekit.models import task as task_models
from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException
from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
from flytekit.remote.remote import FlyteRemote
Expand Down Expand Up @@ -1252,3 +1255,98 @@ def test_register_wf_twice(register):
]
)
assert out.returncode == 0


def test_register_wf_with_resource_requests_override(register):
# Save the version here to retrieve the created task later
version = str(uuid.uuid4())
# Register the workflow with overridden default resources
out = subprocess.run(
[
"pyflyte",
"--verbose",
"-c",
CONFIG,
"register",
"--resource-requests",
"cpu=1300m,mem=1100Mi",
"--image",
IMAGE,
"--project",
PROJECT,
"--domain",
DOMAIN,
"--version",
version,
MODULE_PATH / "hello_world.py",
]
)
assert out.returncode == 0

# Retrieve the created task
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
task = remote.fetch_task(name="basic.hello_world.say_hello", version=version)
assert task.template.container is not None
assert task.template.container.resources == task_models.Resources(
requests=[
task_models.Resources.ResourceEntry(
name=task_models.Resources.ResourceName.CPU,
value="1300m",
),
task_models.Resources.ResourceEntry(
name=task_models.Resources.ResourceName.MEMORY,
value="1100Mi",
),
],
limits=[],
)


def test_run_wf_with_resource_requests_override(register):
# Save the execution id here to retrieve the created execution later
prefix = random.choice(string.ascii_lowercase)
short_random_part = uuid.uuid4().hex[:8]
execution_id = f"{prefix}{short_random_part}"
# Register the workflow with overridden default resources
out = subprocess.run(
[
"pyflyte",
"--verbose",
"-c",
CONFIG,
"run",
"--remote",
"--resource-requests",
"cpu=500m,mem=1Gi",
"--project",
PROJECT,
"--domain",
DOMAIN,
"--name",
execution_id,
MODULE_PATH / "hello_world.py",
"my_wf"
]
)
assert out.returncode == 0

# Retrieve the created task
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.fetch_execution(name=execution_id)
execution = remote.wait(execution=execution)
version = execution.spec.launch_plan.version
task = remote.fetch_task(name="basic.hello_world.say_hello", version=version)
assert task.template.container is not None
assert task.template.container.resources == task_models.Resources(
requests=[
task_models.Resources.ResourceEntry(
name=task_models.Resources.ResourceName.CPU,
value="500m",
),
task_models.Resources.ResourceEntry(
name=task_models.Resources.ResourceName.MEMORY,
value="1Gi",
),
],
limits=[],
)
Loading
Loading