diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index a492e1cba8..52cdc1b0f8 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -17,6 +17,7 @@ from flytekit.clis.sdk_in_container.package import package from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.run import run +from flytekit.clis.sdk_in_container.runv2 import runv2 from flytekit.clis.sdk_in_container.serialize import serialize from flytekit.clis.sdk_in_container.serve import serve from flytekit.clis.sdk_in_container.utils import ErrorHandlingCommand, validate_package @@ -96,6 +97,7 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: int): main.add_command(info) main.add_command(get) main.add_command(execute) +main.add_command(runv2) main.epilog get_plugin().configure_pyflyte_cli(main) diff --git a/flytekit/clis/sdk_in_container/runv2.py b/flytekit/clis/sdk_in_container/runv2.py new file mode 100644 index 0000000000..3b52c00b3b --- /dev/null +++ b/flytekit/clis/sdk_in_container/runv2.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import importlib.util +import json +import sys +from typing import Any, Dict + +import click +import flyte + +import flytekit +from flytekit.migration.task import task_shim +from flytekit.migration.workflow import workflow_shim + + +def _load_module_from_path(path: str): + spec = importlib.util.spec_from_file_location("user_module", path) + if spec is None or spec.loader is None: + raise click.UsageError(f"Cannot load module from: {path}") + mod = importlib.util.module_from_spec(spec) + sys.modules["user_module"] = mod + spec.loader.exec_module(mod) + return mod + + +def _parse_kv(pairs: tuple[str, ...]) -> Dict[str, Any]: + out: Dict[str, Any] = {} + for kv in pairs: + if "=" not in kv: + raise click.UsageError(f"Bad input '{kv}', expected key=value") + k, v = kv.split("=", 1) + # naive coercion + if v.lower() in {"true", "false"}: + out[k] = v.lower() == "true" + else: + try: + out[k] = int(v) if "." not in v else float(v) + except ValueError: + out[k] = v + return out + + +def _run_remote(entity, inputs): + if hasattr(flyte, "remote") and callable(getattr(flyte, "remote")): + r = flyte.remote() + try: + return r.run(entity, **inputs) + except TypeError: + return r.run(entity, inputs=inputs) + elif hasattr(flyte, "submit"): + return flyte.submit(entity, **inputs) + elif hasattr(flyte, "Runner") and hasattr(flyte.Runner, "remote"): + rr = flyte.Runner.remote() + try: + return rr.run(entity, inputs) + except TypeError: + return rr.run(entity, **inputs) + else: + raise click.UsageError( + "Remote execution is not available in this flyte-sdk build. Please upgrade flyte-sdk or configure a remote backend." + ) + + +@click.command("runv2", context_settings={"ignore_unknown_options": True}) +@click.option( + "--remote", + is_flag=True, + default=False, + help="Submit via Flyte 2 remote backend if configured; otherwise run locally.", +) +@click.argument("pyfile", type=click.Path(exists=True)) +@click.argument("entity_name") +@click.option("-i", "--input", "inputs_kv", multiple=True, help="key=value pairs") +@click.option("--config", type=click.Path(exists=True), help="Flyte 2 SDK config file") +def runv2(pyfile: str, entity_name: str, inputs_kv: tuple[str, ...], config: str | None, remote: bool): + """ + pyflyte runv2 xx.py -i a=1 -i b=hello + + Loads the module, applies v2 shims to flytekit decorators, and executes + the selected entity via the Flyte 2 runtime (flyte.run). + """ + # init Flyte 2 + if config: + flyte.init_from_config(config) + else: + flyte.init() + + flytekit.task = task_shim + flytekit.workflow = workflow_shim + + spec = importlib.util.spec_from_file_location("user_module", pyfile) + mod = importlib.util.module_from_spec(spec) + sys.modules["user_module"] = mod + spec.loader.exec_module(mod) # type: ignore + + entity = getattr(mod, entity_name) + inputs = {} + for kv in inputs_kv: + k, v = kv.split("=", 1) + inputs[k] = ( + (v.lower() == "true") + if v.lower() in ("true", "false") + else (float(v) if "." in v else (int(v) if v.isdigit() else v)) + ) + + if remote: + out = _run_remote(entity, inputs) + else: + out = flyte.run(entity, **inputs) + + value = out + for name in ("result", "output", "outputs"): + attr = getattr(value, name, None) + if attr is None: + continue + try: + value = attr() if callable(attr) else attr + except TypeError: + value = attr + + if isinstance(value, dict) and len(value) == 1: + try: + value = next(iter(value.values())) + except Exception: + pass + + try: + click.echo(json.dumps({"result": value}, default=str)) + except Exception: + click.echo(str(value)) diff --git a/flytekit/migration/task.py b/flytekit/migration/task.py new file mode 100644 index 0000000000..14315bad54 --- /dev/null +++ b/flytekit/migration/task.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import datetime +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union + +from flyte import Image, Resources, TaskEnvironment +from flyte._doc import Documentation as V2Docs +from flyte._task import AsyncFunctionTaskTemplate, P, R + +import flytekit +from flytekit.core import launch_plan, workflow +from flytekit.core.base_task import T, TaskResolverMixin +from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.task import FuncOut +from flytekit.deck import DeckField +from flytekit.extras.accelerators import BaseAccelerator + + +def _to_v2_resources(req: Optional[flytekit.Resources], lim: Optional[flytekit.Resources]) -> Optional[Resources]: + if not req and not lim: + return None + + # Pick requests first, then fall back to limits if requests missing. + def pick(getter: Callable[[flytekit.Resources], Any], fallback_getter: Callable[[flytekit.Resources], Any]): + if req and getter(req) is not None: + return getter(req) + if lim and fallback_getter(lim) is not None: + return fallback_getter(lim) + return None + + cpu = pick(lambda r: r.cpu, lambda r: r.cpu) + mem = pick(lambda r: r.mem, lambda r: r.mem) + gpu = pick(lambda r: r.gpu, lambda r: r.gpu) + # Flyte-SDK Resources accepts cpu as float/str, memory as str like "800Mi" + return Resources(cpu=cpu, memory=mem, gpu=gpu) + + +def _to_v2_image(container_image: Optional[Union[str, flytekit.ImageSpec]]) -> Image: + if isinstance(container_image, flytekit.ImageSpec): + img = Image.from_debian_base() + if container_image.apt_packages: + img = img.with_apt_packages(*container_image.apt_packages) + pip_packages = [] + pip_packages.append("flyte") + pip_packages.append("flytekit") + if container_image.packages: + pip_packages.extend(container_image.packages) + return img.with_pip_packages(*pip_packages) + if isinstance(container_image, str): + return Image.from_base(container_image).with_pip_packages("flyte", "flytekit") + # default + return Image.from_debian_base().with_pip_packages("flyte", "flytekit") + + +def task_shim( + _task_function: Optional[Callable[P, FuncOut]] = None, + task_config: Optional[T] = None, + cache: Union[bool, flytekit.Cache] = False, + retries: int = 0, + interruptible: Optional[bool] = None, + deprecated: str = "", + timeout: Union[datetime.timedelta, int] = 0, + container_image: Optional[Union[str, flytekit.ImageSpec]] = None, + environment: Optional[Dict[str, str]] = None, + requests: Optional[flytekit.Resources] = None, + limits: Optional[flytekit.Resources] = None, + secret_requests: Optional[List[flytekit.Secret]] = None, + execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT, + node_dependency_hints: Optional[ + Iterable[ + Union[ + flytekit.PythonFunctionTask, + launch_plan.LaunchPlan, + workflow.WorkflowBase, + ] + ] + ] = None, + task_resolver: Optional[TaskResolverMixin] = None, + docs: Optional[flytekit.Documentation] = None, + disable_deck: Optional[bool] = None, + enable_deck: Optional[bool] = None, + deck_fields: Optional[Tuple[DeckField, ...]] = ( + DeckField.SOURCE_CODE, + DeckField.DEPENDENCIES, + DeckField.TIMELINE, + DeckField.INPUT, + DeckField.OUTPUT, + ), + pod_template: Optional[flytekit.PodTemplate] = None, + pod_template_name: Optional[str] = None, + accelerator: Optional[BaseAccelerator] = None, + pickle_untyped: bool = False, + shared_memory: Optional[Union[Literal[True], str]] = None, + resources: Optional[Resources] = None, # explicit v2 resources passthrough + labels: Optional[dict[str, str]] = None, + annotations: Optional[dict[str, str]] = None, + **kwargs, +) -> Union[AsyncFunctionTaskTemplate, Callable[[Callable[P, R]], AsyncFunctionTaskTemplate]]: + """ + Decorator that mimics flytekit.task but registers a Flyte 2 task under the hood. + Returns a decorator if called with no function; otherwise returns the wrapped task. + """ + + # Build V2 image/resources + image = _to_v2_image(container_image) + if resources is None: + resources = _to_v2_resources(requests, limits) + + v2_docs = V2Docs(description=getattr(docs, "short_description", None)) if docs else None + + # cache mapping + cache_mode: Literal["enabled", "disabled"] + cache_mode = "enabled" if (cache is True or str(cache).lower() == "true") else "disabled" + + # PodTemplate passthrough: prefer name, else object + pod_tpl = pod_template_name or (pod_template and pod_template.pod_spec and pod_template) or None + + env = TaskEnvironment( + name="flytekit", + resources=resources or Resources(cpu=0.8, memory="800Mi"), + image=image, + cache=cache_mode, + plugin_config=task_config, + env=environment, + ) + + def _decorator(fn: Callable[P, R]) -> AsyncFunctionTaskTemplate: + # You can add retries, timeout, accelerator, secrets mapping here as needed + return env.task( + retries=retries, + pod_template=pod_tpl, + docs=v2_docs, + timeout=timeout if isinstance(timeout, int) else int(timeout.total_seconds()) if timeout else 0, + # You may add accelerator, secrets, interruptible, etc. when Flyte 2 exposes them + )(fn) + + # Support both @task and @task() + if _task_function is not None: + return _decorator(_task_function) + return _decorator diff --git a/flytekit/migration/workflow.py b/flytekit/migration/workflow.py new file mode 100644 index 0000000000..ab2b864a80 --- /dev/null +++ b/flytekit/migration/workflow.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Any, Callable, Optional + +from flyte import Image, Resources, TaskEnvironment +from flyte._doc import Documentation as V2Docs + + +def workflow_shim( + _fn: Optional[Callable[..., Any]] = None, + *, + image: Optional[Image] = None, + resources: Optional[Resources] = None, + docs: Optional[str] = None, +): + """ + Simple replacement for @flytekit.workflow that wraps the Python function + as a Flyte 2 task (pure-Python orchestration). + """ + env = TaskEnvironment( + name="flytekit", + resources=resources or Resources(cpu=0.8, memory="800Mi"), + image=image or Image.from_debian_base().with_pip_packages("flyte", "flytekit"), + ) + v2_docs = V2Docs(description=docs) if docs else None + + def _decorator(fn: Callable[..., Any]): + # Turn the "workflow" into a task that calls user code directly. + # In Flyte 2 you orchestrate with Python (loops/await/gather inside fn). + return env.task(docs=v2_docs)(fn) + + if _fn is not None: + return _decorator(_fn) + return _decorator