Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions examples/advanced/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import flyte

env = flyte.TaskEnvironment(name="events")


@env.task
async def next_task(value: int) -> int:
return value + 1


@env.task
async def my_task(x: int) -> int:
event1 = await flyte.new_event.aio(
"my_event",
scope="run",
prompt="Is it ok to continue?",
data_type=bool,
)
event2 = await flyte.new_event.aio(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should allow users to make the prompt fancy, like html etc (reports)
And then we should also have an even type that is a webhook invocation with a callback to signal

"proceed_event",
scope="run",
prompt="What should I add to x?",
data_type=int,
)
event3 = await flyte.new_event.aio(
"final_event",
scope="run",
prompt="What should I return if the first event was negative?",
data_type=int,
)
result = await event1.wait.aio()
if result:
print("Event signaled positive response, proceeding to next_task", flush=True)
result2 = await event2.wait.aio()
return await next_task(x + result2)
else:
print("Event signaled negative response, returning -1", flush=True)
result3 = await event3.wait.aio()
return result3


if __name__ == "__main__":
flyte.init()

r = flyte.run(my_task, x=10)
print(r.url)
print(r.outputs())

import flyte.remote as remote

while not (remote_event := remote.Event.get("my_event", r.name)):
time.sleep(10)

remote_event.signal(True)
2 changes: 2 additions & 0 deletions src/flyte/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._custom_context import custom_context, get_custom_context
from ._deploy import build_images, deploy
from ._environment import Environment
from ._event import new_event
from ._excepthook import custom_excepthook
from ._group import group
from ._image import Image
Expand Down Expand Up @@ -100,6 +101,7 @@ def version() -> str:
"init_from_config",
"logger",
"map",
"new_event",
"run",
"trace",
"version",
Expand Down
108 changes: 108 additions & 0 deletions src/flyte/_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import typing
from dataclasses import dataclass
from typing import Generic, Literal, Type

import rich.repr

from flyte.syncify import syncify

EventScope = Literal["task", "run", "action"]

EventType = typing.TypeVar("EventType", bool, int, float, str)


@rich.repr.auto
@dataclass
class _Event(Generic[EventType]):
"""
An event that can be awaited in a Run. Events can be used to pause Run until an external signal is received.

Examples:

```python
import flyte

env = flyte.TaskEnvironment(name="events")

@env.task
async def my_task() -> Optional[int]:
event = await flyte.new_event(name="my_event", scope="run", prompt="Is it ok to continue?", data_type=bool)
result = await event.wait()
if result:
return 42
else:
return None
```
"""

name: str
# TODO restrict scope to action only right now
scope: EventScope = "run"
# TODO Support prompt as html
prompt: str = "Approve?"
data_type: Type[EventType] = bool # type: ignore[assignment]
description: str = ""

def __post_init__(self):
valid_types = (bool, int, float, str)
if self.data_type not in valid_types:
raise TypeError(f"Invalid data_type {self.data_type}. Must be one of {valid_types}.")

@syncify
async def wait(self) -> EventType:
"""
Await the event to be signaled.

:return: The payload associated with the event when it is signaled.
"""
from flyte._context import internal_ctx

ctx = internal_ctx()
if ctx.is_task_context():
# If we are in a task context, that implies we are executing a Run.
# In this scenario, we should submit the task to the controller.
# We will also check if we are not initialized, It is not expected to be not initialized
from ._internal.controllers import get_controller

controller = get_controller()
result = await controller.wait_for_event(self)
return result
else:
raise RuntimeError("Events can only be awaited within a task context.")


@syncify
async def new_event(
name: str,
/,
scope: EventScope = "run",
prompt: str = "Approve?",
data_type: Type[EventType] = bool, # type: ignore[assignment]
description: str = "",
) -> _Event:
"""
Create an event that can be awaited in a workflow. Events can be used to pause workflow execution until
an external signal is received.

:param name: Name of the event
:param scope: Scope of the event - "task", "run", or "action"
:param prompt: Prompt message for the event
:param data_type: Data type of the event payload
:param description: Description of the event
:return: An instance of _Event representing the created event
"""
event = _Event(name=name, scope=scope, prompt=prompt, data_type=data_type, description=description)
from flyte._context import internal_ctx

ctx = internal_ctx()
if ctx.is_task_context():
# If we are in a task context, that implies we are executing a Run.
# In this scenario, we should submit the task to the controller.
# We will also check if we are not initialized, It is not expected to be not initialized
from ._internal.controllers import get_controller

controller = get_controller()
await controller.register_event(event)
else:
pass
return event
17 changes: 17 additions & 0 deletions src/flyte/_internal/controllers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ async def record_trace(self, info: TraceInfo):
"""
...

async def register_event(self, event: Any):
"""
Register an event that can be awaited. This is used to register events that can pause execution
until an external signal is received.
:param event: Event object to register
:return:
"""
...

async def wait_for_event(self, event: Any) -> Any:
"""
Wait for an event to be signaled. This will block until the event receives data.
:param event: Event object to wait for
:return: The payload associated with the event when it is signaled
"""
...

async def stop(self):
"""
Stops the engine and should be called when the engine is no longer needed.
Expand Down
56 changes: 56 additions & 0 deletions src/flyte/_internal/controllers/_local_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class LocalController:
def __init__(self):
logger.debug("LocalController init")
self._runner_map: dict[str, _TaskRunner] = {}
self._registered_events: dict[str, Any] = {}

@log
async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
Expand Down Expand Up @@ -237,3 +238,58 @@ async def submit_task_ref(self, _task: TaskDetails, max_inline_io_bytes: int, *a
raise flyte.errors.ReferenceTaskError(
f"Reference tasks cannot be executed locally, only remotely. Found remote task {_task.name}"
)

async def register_event(self, event: Any):
"""
Register an event that can be awaited. Stores the event for later retrieval.

:param event: Event object to register
"""
from flyte._event import _Event

if not isinstance(event, _Event):
raise TypeError(f"Expected _Event, got {type(event)}")

logger.debug(f"Registering event: {event.name} with scope: {event.scope}")
self._registered_events[event.name] = event

async def wait_for_event(self, event: Any) -> Any:
"""
Wait for an event to be signaled. Uses rich library to prompt the user for input.

:param event: Event object to wait for
:return: The payload associated with the event when it is signaled
"""
from rich.console import Console
from rich.prompt import Confirm, Prompt

from flyte._event import _Event

if not isinstance(event, _Event):
raise TypeError(f"Expected _Event, got {type(event)}")

logger.info(f"Waiting for event: {event.name}")

console = Console()
console.print(f"\n[bold cyan]Event:[/bold cyan] {event.name}")
if event.description:
console.print(f"[dim]{event.description}[/dim]")

# Handle different data types
if event.data_type is bool:
result = Confirm.ask(event.prompt, console=console)
elif event.data_type in (int, float, str):
# For int, float, str - use the same prompt with type conversion
while True:
try:
value = Prompt.ask(event.prompt, console=console)
result = event.data_type(value)
break
except ValueError:
type_name = event.data_type.__name__
console.print(f"[red]Please enter a valid {type_name}[/red]")
else:
raise ValueError(f"Unsupported data type {event.data_type}")

logger.debug(f"Event {event.name} received value: {result}")
return result
21 changes: 21 additions & 0 deletions src/flyte/_internal/controllers/remote/_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,24 @@ async def submit_task_ref(self, _task: TaskDetails, *args, **kwargs) -> Any:
task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
return await self._submit_task_ref(task_call_seq, _task, *args, **kwargs)

async def register_event(self, event: Any):
"""
Register an event that can be awaited.

TODO: Implement remote event registration

:param event: Event object to register
"""
raise NotImplementedError("Remote event registration is not yet implemented")

async def wait_for_event(self, event: Any) -> Any:
"""
Wait for an event to be signaled.

TODO: Implement remote event waiting

:param event: Event object to wait for
:return: The payload associated with the event when it is signaled
"""
raise NotImplementedError("Remote event waiting is not yet implemented")
27 changes: 27 additions & 0 deletions src/flyte/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,30 @@ class OnlyAsyncIOSupportedError(RuntimeUserError):

def __init__(self, message: str):
super().__init__("OnlyAsyncIOSupportedError", message, "user")


class EventAlreadyExistsInScopeError(RuntimeUserError):
"""
This error is raised when the user tries to create an event that already exists in the given scope.
"""

def __init__(self, message: str):
super().__init__("EventAlreadyExistsInScopeError", message, "user")


class EventNotFoundError(RuntimeUserError):
"""
This error is raised when the user tries to access an event that does not exist.
"""

def __init__(self, message: str):
super().__init__("EventNotFoundError", message, "user")


class EventScopeRequiredError(RuntimeUserError):
"""
This error is raised when the user tries to access an event without specifying the scope.
"""

def __init__(self, message: str):
super().__init__("EventScopeRequiredError", message, "user")
Loading
Loading