Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/pyodide/internal/workers-api/src/workers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
Request,
RequestInitCfProperties,
Response,
RetryConfig,
RollbackConfig,
RollbackStep,
StepConfig,
UndoHandler,
WorkerEntrypoint,
WorkflowEntrypoint,
fetch,
Expand Down Expand Up @@ -42,6 +47,11 @@
"Request",
"RequestInitCfProperties",
"Response",
"RetryConfig",
"RollbackConfig",
"RollbackStep",
"StepConfig",
"UndoHandler",
"WorkerEntrypoint",
"WorkflowEntrypoint",
"env",
Expand Down
273 changes: 272 additions & 1 deletion src/pyodide/internal/workers-api/src/workers/_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@
from enum import StrEnum
from http import HTTPMethod, HTTPStatus
from types import LambdaType
from typing import Any, Never, Protocol, TypedDict, Unpack
from typing import (
Any,
Awaitable,
Callable,
Generic,
Literal,
Never,
Protocol,
TypedDict,
TypeVar,
Unpack,
)

# Get globals modules and import function from the entrypoint-helper
import _pyodide_entrypoint_helper
Expand All @@ -38,6 +49,34 @@
from pyodide.http import pyfetch
from workers.workflows import NonRetryableError

# Type definitions for workflow steps
T = TypeVar("T")


class RetryConfig(TypedDict, total=False):
"""Configuration for step retry behavior."""

limit: int
delay: str | int
backoff: Literal["constant", "linear", "exponential"]


class StepConfig(TypedDict, total=False):
"""Configuration for workflow step execution."""

retries: RetryConfig
timeout: str | int


class RollbackConfig(TypedDict, total=False):
"""Configuration for workflow rollback behavior at instance creation."""

continue_on_error: bool


# Undo function signature: (err, value) -> Awaitable[None]
UndoHandler = Callable[[Exception | None, T], Awaitable[None]]


class Context(Protocol):
def waitUntil(self, other: Awaitable[Any]) -> None: ...
Expand Down Expand Up @@ -1125,6 +1164,115 @@ def wrapper(*args, **kwargs):
return wrapper


class RollbackStep(Generic[T]):
"""
Wrapper returned by @step.with_rollback decorator.

Delegates to the engine's withRollback for durable undo stack management.

Usage:
# Pattern A: Chained decorator (preferred - keeps do/undo together)
@step.with_rollback("save to db")
async def save():
return await db.insert(data)

@save.undo
async def _(error, record_id):
await db.delete(record_id)

record_id = await save()

# Pattern B: Parameter (for reusable undo handlers)
@step.with_rollback("save to db", undo=generic_delete)
async def save():
return await db.insert(data)
"""

def __init__(
self,
step_wrapper: "_WorkflowStepWrapper",
name: str,
do_fn: Callable[..., Awaitable[T]],
*,
undo: UndoHandler[T] | None = None,
depends: list[Callable[..., Awaitable[Any]]] | None = None,
concurrent: bool = False,
config: StepConfig | None = None,
undo_config: StepConfig | None = None,
):
self._step_wrapper = step_wrapper
self._name = name
self._do_fn = do_fn
self._undo_handler = undo
self._depends = depends
self._concurrent = concurrent
self._config = config
self._undo_config = undo_config
self._step_name = name # For dependency resolution

def undo(
self, fn_or_config: UndoHandler[T] | StepConfig | None = None
) -> UndoHandler[T] | Callable[[UndoHandler[T]], UndoHandler[T]]:
"""
Decorator to register an undo/compensation function for this step.

The undo function receives (error, value) where value is the result
of the do function.

Args:
fn_or_config: Either the undo function directly (@fn.undo) or
a StepConfig dict (@fn.undo(config={...}))
"""
# Support @fn.undo (no parens)
if callable(fn_or_config):
self._undo_handler = fn_or_config
return fn_or_config

# Support @fn.undo() or @fn.undo(config={...})
config = fn_or_config

def decorator(fn: UndoHandler[T]) -> UndoHandler[T]:
self._undo_handler = fn
if config is not None:
self._undo_config = config
return fn

return decorator

async def __call__(self) -> T:
"""Execute the step via engine's withRollback for durable undo stack."""
if self._undo_handler is None:
raise ValueError(
f"Step '{self._name}' requires an undo handler. "
f"Add @{self._do_fn.__name__}.undo or pass undo= parameter."
)

# Resolve dependencies (same pattern as step.do)
if self._concurrent:
results = await gather(
*[
self._step_wrapper._resolve_dependency(dep)
for dep in self._depends or []
]
)
else:
results = [
await self._step_wrapper._resolve_dependency(dep)
for dep in self._depends or []
]
python_results = [python_from_rpc(r) for r in results]

return await _withRollback_call(
self._step_wrapper,
self._name,
self._config,
self._undo_config,
self._do_fn,
self._undo_handler,
*python_results,
)


class _WorkflowStepWrapper:
def __init__(self, js_step):
self._js_step = js_step
Expand Down Expand Up @@ -1169,6 +1317,67 @@ def wait_for_event(self, name, event_type, /, timeout="24 hours"):
),
)

def with_rollback(
self,
name: str,
*,
undo: UndoHandler[T] | None = None,
depends: list[Callable[..., Awaitable[Any]]] | None = None,
concurrent: bool = False,
config: StepConfig | None = None,
undo_config: StepConfig | None = None,
) -> Callable[[Callable[..., Awaitable[T]]], RollbackStep[T]]:
"""
Decorator for step with rollback/compensation support (saga pattern).

Returns a callable wrapper that allows attaching an .undo decorator for
compensation logic. Undo functions execute automatically in LIFO order
when the workflow throws an uncaught error (if rollback config is enabled
at instance creation).

Args:
name: Step name (up to 256 chars)
undo: Undo handler, or use @decorated_fn.undo
depends: Steps this depends on (DAG pattern)
concurrent: Run dependencies in parallel
config: Retry/timeout config for do()
undo_config: Retry/timeout config for undo()

Raises:
ValueError: If no undo handler provided via parameter or decorator

Usage:
# Pattern A: Chained decorator (preferred)
@step.with_rollback("save to db")
async def save():
return await db.insert(data)

@save.undo
async def _(error, record_id):
await db.delete(record_id)

record_id = await save()

# Pattern B: Parameter (for reusable undo handlers)
@step.with_rollback("save to db", undo=generic_delete)
async def save():
return await db.insert(data)
"""

def decorator(func: Callable[..., Awaitable[T]]) -> RollbackStep[T]:
return RollbackStep(
self,
name,
func,
undo=undo,
depends=depends,
concurrent=concurrent,
config=config,
undo_config=undo_config,
)

return decorator

async def _resolve_dependency(self, dep):
if dep._step_name in self._memoized_dependencies:
return self._memoized_dependencies[dep._step_name]
Expand Down Expand Up @@ -1211,6 +1420,68 @@ async def _closure():
return result


async def _withRollback_call(
entrypoint, name, config, undo_config, do_fn, undo_fn, *dep_results
):
"""Call the engine's withRollback with Python callbacks wrapped for JS."""

async def _closure():
async def _do_callback():
result = do_fn(*dep_results)
if inspect.iscoroutine(result):
result = await result
return to_js(result, dict_converter=Object.fromEntries)

async def _undo_callback(js_err, js_value):
py_err = None
if js_err is not None:
py_err = (
_from_js_error(js_err) if hasattr(js_err, "message") else js_err
)

py_value = python_from_rpc(js_value)

result = undo_fn(py_err, py_value)
if inspect.iscoroutine(result):
await result

handler = {"do": _do_callback}
if undo_fn is not None:
handler["undo"] = _undo_callback

js_handler = to_js(handler, dict_converter=Object.fromEntries)

js_config = None
if config is not None or undo_config is not None:
config_dict = dict(config) if config else {}
if undo_config is not None:
config_dict["undoConfig"] = undo_config
js_config = to_js(config_dict, dict_converter=Object.fromEntries)

try:
if js_config is None:
result = await entrypoint._js_step.withRollback(name, js_handler)
else:
result = await entrypoint._js_step.withRollback(
name, js_handler, js_config
)

return python_from_rpc(result)
except Exception as exc:
raise _from_js_error(exc) from exc

task = create_task(_closure())
entrypoint._in_flight[name] = task

try:
result = await task
entrypoint._memoized_dependencies[name] = result
finally:
del entrypoint._in_flight[name]

return result


def _wrap_subclass(cls):
# Override the class __init__ so that we can wrap the `env` in the constructor.
original_init = cls.__init__
Expand Down
2 changes: 2 additions & 0 deletions src/workerd/server/tests/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ py_wd_test("python-rpc")

py_wd_test("workflow-entrypoint")

py_wd_test("workflow-rollback")

py_wd_test("vendor_dir_compat_flag")

py_wd_test("default-class-with-legacy-global-handlers")
Expand Down
Loading
Loading