Skip to content

Commit 09e03f1

Browse files
authored
Outcome.intercept NoResult (#174)
The original implementation of `Outcome.intercept` used `Optional[T]` and thus could not distingish between no cached results and a cached result of None. This PR adds a singleton NoResult class to use instead of `None` in as the return value of the `interceptor` parameter of `Outcome.intercept`.
1 parent 6299fd8 commit 09e03f1

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

dbos/_core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
overload,
2222
)
2323

24-
from dbos._outcome import Immediate, Outcome, Pending
24+
from dbos._outcome import Immediate, NoResult, Outcome, Pending
2525

2626
from ._app_db import ApplicationDatabase, TransactionResultInternal
2727

@@ -719,7 +719,7 @@ def record_step_result(func: Callable[[], R]) -> R:
719719
finally:
720720
dbos._sys_db.record_operation_result(step_output)
721721

722-
def check_existing_result() -> Optional[str]:
722+
def check_existing_result() -> Union[NoResult, R]:
723723
ctx = assert_current_dbos_context()
724724
recorded_output = dbos._sys_db.check_operation_execution(
725725
ctx.workflow_id, ctx.function_id
@@ -734,14 +734,16 @@ def check_existing_result() -> Optional[str]:
734734
)
735735
raise deserialized_error
736736
elif recorded_output["output"] is not None:
737-
return recorded_output["output"]
737+
return cast(
738+
R, _serialization.deserialize(recorded_output["output"])
739+
)
738740
else:
739741
raise Exception("Output and error are both None")
740742
else:
741743
dbos.logger.debug(
742744
f"Running step, id: {ctx.function_id}, name: {attributes['name']}"
743745
)
744-
return None
746+
return NoResult()
745747

746748
stepOutcome = Outcome[R].make(functools.partial(func, *args, **kwargs))
747749
if retries_allowed:

dbos/_outcome.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
import time
55
from typing import Any, Callable, Coroutine, Optional, Protocol, TypeVar, Union, cast
66

7-
from . import _serialization
8-
97
T = TypeVar("T")
108
R = TypeVar("R")
119

1210

11+
class NoResult:
12+
_instance: Optional["NoResult"] = None
13+
__slots__ = ()
14+
15+
def __new__(cls, *args: Any, **kwargs: Any) -> "NoResult":
16+
if not cls._instance:
17+
cls._instance = super(NoResult, cls).__new__(cls, *args, **kwargs)
18+
return cls._instance
19+
20+
1321
# define Outcome protocol w/ common composition methods
1422
class Outcome(Protocol[T]):
1523

@@ -30,7 +38,9 @@ def retry(
3038
exceeded_retries: Callable[[int], BaseException],
3139
) -> "Outcome[T]": ...
3240

33-
def intercept(self, interceptor: Callable[[], Optional[str]]) -> "Outcome[T]": ...
41+
def intercept(
42+
self, interceptor: Callable[[], Union[NoResult, T]]
43+
) -> "Outcome[T]": ...
3444

3545
def __call__(self) -> Union[T, Coroutine[Any, Any, T]]: ...
3646

@@ -61,14 +71,14 @@ def wrap(
6171

6272
@staticmethod
6373
def _intercept(
64-
func: Callable[[], T], interceptor: Callable[[], Optional[str]]
74+
func: Callable[[], T], interceptor: Callable[[], Union[NoResult, T]]
6575
) -> T:
6676
intercepted = interceptor()
67-
return (
68-
cast(T, _serialization.deserialize(intercepted)) if intercepted else func()
69-
)
77+
return intercepted if not isinstance(intercepted, NoResult) else func()
7078

71-
def intercept(self, interceptor: Callable[[], Optional[str]]) -> "Immediate[T]":
79+
def intercept(
80+
self, interceptor: Callable[[], Union[NoResult, T]]
81+
) -> "Immediate[T]":
7282
return Immediate[T](lambda: Immediate._intercept(self._func, interceptor))
7383

7484
@staticmethod
@@ -157,16 +167,12 @@ def also(self, cm: contextlib.AbstractContextManager[Any, bool]) -> "Pending[T]"
157167
@staticmethod
158168
async def _intercept(
159169
func: Callable[[], Coroutine[Any, Any, T]],
160-
interceptor: Callable[[], Optional[str]],
170+
interceptor: Callable[[], Union[NoResult, T]],
161171
) -> T:
162172
intercepted = await asyncio.to_thread(interceptor)
163-
return (
164-
cast(T, _serialization.deserialize(intercepted))
165-
if intercepted
166-
else await func()
167-
)
173+
return intercepted if not isinstance(intercepted, NoResult) else await func()
168174

169-
def intercept(self, interceptor: Callable[[], Optional[str]]) -> "Pending[T]":
175+
def intercept(self, interceptor: Callable[[], Union[NoResult, T]]) -> "Pending[T]":
170176
return Pending[T](lambda: Pending._intercept(self._func, interceptor))
171177

172178
@staticmethod

0 commit comments

Comments
 (0)