Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flytekit/clis/sdk_in_container/pyflyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flytekit.clis.sdk_in_container.package import package
from flytekit.clis.sdk_in_container.register import register
from flytekit.clis.sdk_in_container.run import run
from flytekit.clis.sdk_in_container.runv2 import runv2
from flytekit.clis.sdk_in_container.serialize import serialize
from flytekit.clis.sdk_in_container.serve import serve
from flytekit.clis.sdk_in_container.utils import ErrorHandlingCommand, validate_package
Expand Down Expand Up @@ -96,6 +97,7 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: int):
main.add_command(info)
main.add_command(get)
main.add_command(execute)
main.add_command(runv2)
main.epilog

get_plugin().configure_pyflyte_cli(main)
Expand Down
130 changes: 130 additions & 0 deletions flytekit/clis/sdk_in_container/runv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

import importlib.util
import json
import sys
from typing import Any, Dict

import click
import flyte

import flytekit
from flytekit.migration.task import task_shim
from flytekit.migration.workflow import workflow_shim


def _load_module_from_path(path: str):
spec = importlib.util.spec_from_file_location("user_module", path)
if spec is None or spec.loader is None:
raise click.UsageError(f"Cannot load module from: {path}")
mod = importlib.util.module_from_spec(spec)
sys.modules["user_module"] = mod
spec.loader.exec_module(mod)
return mod


def _parse_kv(pairs: tuple[str, ...]) -> Dict[str, Any]:
out: Dict[str, Any] = {}
for kv in pairs:
if "=" not in kv:
raise click.UsageError(f"Bad input '{kv}', expected key=value")
k, v = kv.split("=", 1)
# naive coercion
if v.lower() in {"true", "false"}:
out[k] = v.lower() == "true"
else:
try:
out[k] = int(v) if "." not in v else float(v)
except ValueError:
out[k] = v
return out


def _run_remote(entity, inputs):
if hasattr(flyte, "remote") and callable(getattr(flyte, "remote")):
r = flyte.remote()
try:
return r.run(entity, **inputs)
except TypeError:
return r.run(entity, inputs=inputs)
elif hasattr(flyte, "submit"):
return flyte.submit(entity, **inputs)
elif hasattr(flyte, "Runner") and hasattr(flyte.Runner, "remote"):
rr = flyte.Runner.remote()
try:
return rr.run(entity, inputs)
except TypeError:
return rr.run(entity, **inputs)
else:
raise click.UsageError(
"Remote execution is not available in this flyte-sdk build. Please upgrade flyte-sdk or configure a remote backend."
)


@click.command("runv2", context_settings={"ignore_unknown_options": True})
@click.option(
"--remote",
is_flag=True,
default=False,
help="Submit via Flyte 2 remote backend if configured; otherwise run locally.",
)
@click.argument("pyfile", type=click.Path(exists=True))
@click.argument("entity_name")
@click.option("-i", "--input", "inputs_kv", multiple=True, help="key=value pairs")
@click.option("--config", type=click.Path(exists=True), help="Flyte 2 SDK config file")
def runv2(pyfile: str, entity_name: str, inputs_kv: tuple[str, ...], config: str | None, remote: bool):
"""
pyflyte runv2 xx.py <workflow_or_task_name> -i a=1 -i b=hello

Loads the module, applies v2 shims to flytekit decorators, and executes
the selected entity via the Flyte 2 runtime (flyte.run).
"""
# init Flyte 2
if config:
flyte.init_from_config(config)
else:
flyte.init()

flytekit.task = task_shim
flytekit.workflow = workflow_shim

spec = importlib.util.spec_from_file_location("user_module", pyfile)
mod = importlib.util.module_from_spec(spec)
sys.modules["user_module"] = mod
spec.loader.exec_module(mod) # type: ignore

entity = getattr(mod, entity_name)
inputs = {}
for kv in inputs_kv:
k, v = kv.split("=", 1)
inputs[k] = (
(v.lower() == "true")
if v.lower() in ("true", "false")
else (float(v) if "." in v else (int(v) if v.isdigit() else v))
)

if remote:
out = _run_remote(entity, inputs)
else:
out = flyte.run(entity, **inputs)

value = out
for name in ("result", "output", "outputs"):
attr = getattr(value, name, None)
if attr is None:
continue
try:
value = attr() if callable(attr) else attr
except TypeError:
value = attr

if isinstance(value, dict) and len(value) == 1:
try:
value = next(iter(value.values()))
except Exception:
pass

try:
click.echo(json.dumps({"result": value}, default=str))
except Exception:
click.echo(str(value))
140 changes: 140 additions & 0 deletions flytekit/migration/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from __future__ import annotations

import datetime
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union

from flyte import Image, Resources, TaskEnvironment
from flyte._doc import Documentation as V2Docs
from flyte._task import AsyncFunctionTaskTemplate, P, R

import flytekit
from flytekit.core import launch_plan, workflow
from flytekit.core.base_task import T, TaskResolverMixin
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.task import FuncOut
from flytekit.deck import DeckField
from flytekit.extras.accelerators import BaseAccelerator


def _to_v2_resources(req: Optional[flytekit.Resources], lim: Optional[flytekit.Resources]) -> Optional[Resources]:
if not req and not lim:
return None

# Pick requests first, then fall back to limits if requests missing.
def pick(getter: Callable[[flytekit.Resources], Any], fallback_getter: Callable[[flytekit.Resources], Any]):
if req and getter(req) is not None:
return getter(req)
if lim and fallback_getter(lim) is not None:
return fallback_getter(lim)
return None

cpu = pick(lambda r: r.cpu, lambda r: r.cpu)
mem = pick(lambda r: r.mem, lambda r: r.mem)
gpu = pick(lambda r: r.gpu, lambda r: r.gpu)
# Flyte-SDK Resources accepts cpu as float/str, memory as str like "800Mi"
return Resources(cpu=cpu, memory=mem, gpu=gpu)


def _to_v2_image(container_image: Optional[Union[str, flytekit.ImageSpec]]) -> Image:
if isinstance(container_image, flytekit.ImageSpec):
img = Image.from_debian_base()
if container_image.apt_packages:
img = img.with_apt_packages(*container_image.apt_packages)
pip_packages = []
pip_packages.append("flyte")
pip_packages.append("flytekit")
if container_image.packages:
pip_packages.extend(container_image.packages)
return img.with_pip_packages(*pip_packages)
if isinstance(container_image, str):
return Image.from_base(container_image).with_pip_packages("flyte", "flytekit")
# default
return Image.from_debian_base().with_pip_packages("flyte", "flytekit")


def task_shim(
_task_function: Optional[Callable[P, FuncOut]] = None,
task_config: Optional[T] = None,
cache: Union[bool, flytekit.Cache] = False,
retries: int = 0,
interruptible: Optional[bool] = None,
deprecated: str = "",
timeout: Union[datetime.timedelta, int] = 0,
container_image: Optional[Union[str, flytekit.ImageSpec]] = None,
environment: Optional[Dict[str, str]] = None,
requests: Optional[flytekit.Resources] = None,
limits: Optional[flytekit.Resources] = None,
secret_requests: Optional[List[flytekit.Secret]] = None,
execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT,
node_dependency_hints: Optional[
Iterable[
Union[
flytekit.PythonFunctionTask,
launch_plan.LaunchPlan,
workflow.WorkflowBase,
]
]
] = None,
task_resolver: Optional[TaskResolverMixin] = None,
docs: Optional[flytekit.Documentation] = None,
disable_deck: Optional[bool] = None,
enable_deck: Optional[bool] = None,
deck_fields: Optional[Tuple[DeckField, ...]] = (
DeckField.SOURCE_CODE,
DeckField.DEPENDENCIES,
DeckField.TIMELINE,
DeckField.INPUT,
DeckField.OUTPUT,
),
pod_template: Optional[flytekit.PodTemplate] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
shared_memory: Optional[Union[Literal[True], str]] = None,
resources: Optional[Resources] = None, # explicit v2 resources passthrough
labels: Optional[dict[str, str]] = None,
annotations: Optional[dict[str, str]] = None,
**kwargs,
) -> Union[AsyncFunctionTaskTemplate, Callable[[Callable[P, R]], AsyncFunctionTaskTemplate]]:
"""
Decorator that mimics flytekit.task but registers a Flyte 2 task under the hood.
Returns a decorator if called with no function; otherwise returns the wrapped task.
"""

# Build V2 image/resources
image = _to_v2_image(container_image)
if resources is None:
resources = _to_v2_resources(requests, limits)

v2_docs = V2Docs(description=getattr(docs, "short_description", None)) if docs else None

# cache mapping
cache_mode: Literal["enabled", "disabled"]
cache_mode = "enabled" if (cache is True or str(cache).lower() == "true") else "disabled"

# PodTemplate passthrough: prefer name, else object
pod_tpl = pod_template_name or (pod_template and pod_template.pod_spec and pod_template) or None

env = TaskEnvironment(
name="flytekit",
resources=resources or Resources(cpu=0.8, memory="800Mi"),
image=image,
cache=cache_mode,
plugin_config=task_config,
env=environment,
)

def _decorator(fn: Callable[P, R]) -> AsyncFunctionTaskTemplate:
# You can add retries, timeout, accelerator, secrets mapping here as needed
return env.task(
retries=retries,
pod_template=pod_tpl,
docs=v2_docs,
timeout=timeout if isinstance(timeout, int) else int(timeout.total_seconds()) if timeout else 0,
# You may add accelerator, secrets, interruptible, etc. when Flyte 2 exposes them
)(fn)

# Support both @task and @task()
if _task_function is not None:
return _decorator(_task_function)
return _decorator
34 changes: 34 additions & 0 deletions flytekit/migration/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

from typing import Any, Callable, Optional

from flyte import Image, Resources, TaskEnvironment
from flyte._doc import Documentation as V2Docs


def workflow_shim(
_fn: Optional[Callable[..., Any]] = None,
*,
image: Optional[Image] = None,
resources: Optional[Resources] = None,
docs: Optional[str] = None,
):
"""
Simple replacement for @flytekit.workflow that wraps the Python function
as a Flyte 2 task (pure-Python orchestration).
"""
env = TaskEnvironment(
name="flytekit",
resources=resources or Resources(cpu=0.8, memory="800Mi"),
image=image or Image.from_debian_base().with_pip_packages("flyte", "flytekit"),
)
v2_docs = V2Docs(description=docs) if docs else None

def _decorator(fn: Callable[..., Any]):
# Turn the "workflow" into a task that calls user code directly.
# In Flyte 2 you orchestrate with Python (loops/await/gather inside fn).
return env.task(docs=v2_docs)(fn)

if _fn is not None:
return _decorator(_fn)
return _decorator
Loading