diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index eb8164f2b0..268ed0cc28 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -810,6 +810,7 @@ class DynamicEntityLaunchCommand(click.RichCommand): LP_LAUNCHER = "lp" TASK_LAUNCHER = "task" + WORKFLOW_LAUNCHER = "workflow" def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs): super().__init__(name=name, help=h, **kwargs) @@ -817,7 +818,7 @@ def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs) self._launcher = launcher self._entity = None - def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]: + def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask, FlyteWorkflow]: if self._entity: return self._entity run_level_params: RunLevelParams = ctx.obj @@ -837,6 +838,12 @@ def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, Fly ) ) entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name) + elif self._launcher == self.WORKFLOW_LAUNCHER: + parts = self._entity_name.split(":") + if len(parts) == 2: + entity = r.fetch_workflow(run_level_params.project, run_level_params.domain, parts[0], parts[1]) + else: + entity = r.fetch_workflow(run_level_params.project, run_level_params.domain, self._entity_name) else: parts = self._entity_name.split(":") if len(parts) == 2: @@ -973,13 +980,20 @@ def list_commands(self, ctx): return [] def get_command(self, ctx, name): - if self._command_name in [self.LAUNCHPLAN_COMMAND, self.WORKFLOW_COMMAND]: + if self._command_name == self.LAUNCHPLAN_COMMAND: return DynamicEntityLaunchCommand( name=name, h=f"Execute a {self._command_name}.", entity_name=name, launcher=DynamicEntityLaunchCommand.LP_LAUNCHER, ) + elif self._command_name == self.WORKFLOW_COMMAND: + return DynamicEntityLaunchCommand( + name=name, + h=f"Execute a {self._command_name}.", + entity_name=name, + launcher=DynamicEntityLaunchCommand.WORKFLOW_LAUNCHER, + ) return DynamicEntityLaunchCommand( name=name, h=f"Execute a {self._command_name}.", diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 3e3209671c..dd8bc64f4a 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -962,3 +962,63 @@ def example_task(flag: bool) -> bool: args, _ = mock_run_remote.call_args inputs = args[4]['flag'] assert inputs == False + + +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.run.run_remote") +def test_remote_workflow(mock_run_remote, mock_remote): + @task() + def example_task(x: int, y: str) -> str: + return f"{x},{y}" + + @workflow + def example_workflow(x: int, y: str) -> str: + return example_task(x=x, y=y) + + mock_remote_instance = mock.MagicMock() + mock_remote.return_value = mock_remote_instance + mock_remote_instance.fetch_workflow.return_value = example_workflow + + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", "remote-workflow", "some_module.example_workflow", "--x", "42", "--y", "hello"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_remote_instance.fetch_workflow.assert_called_once() + mock_run_remote.assert_called_once() + + +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.run.run_remote") +def test_remote_workflow_with_version(mock_run_remote, mock_remote): + @task() + def example_task(x: int) -> int: + return x * 2 + + @workflow + def example_workflow(x: int) -> int: + return example_task(x=x) + + mock_remote_instance = mock.MagicMock() + mock_remote.return_value = mock_remote_instance + mock_remote_instance.fetch_workflow.return_value = example_workflow + + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", "remote-workflow", "some_module.example_workflow:v1", "--x", "10"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + # Verify fetch_workflow was called with the correct arguments (project, domain, name, version) + mock_remote_instance.fetch_workflow.assert_called_once() + call_args = mock_remote_instance.fetch_workflow.call_args[0] + # Should be called with 4 args when version is specified + assert len(call_args) == 4 + assert call_args[2] == "some_module.example_workflow" + assert call_args[3] == "v1" + mock_run_remote.assert_called_once()