Skip to content

Commit 2c68918

Browse files
committed
Fix type specification and add explanation
1 parent d1cb62b commit 2c68918

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

shiny/reactive.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,28 @@ async def __call__(self) -> T:
191191
return await self.get_value()
192192

193193

194+
# Note that the specified return type of calc() isn't exactly the same as the actual
195+
# returned object -- the former specifes a Callable that takes a CalcFunction[T], and
196+
# the latter is a Callable that takes CalcFunction[T] | CalcFunctionAsync[T]. Both are
197+
# technically correct, since the CalcFunction's T encompasses both "regular" types V as
198+
# well as Awatiable[V]. (We're using V to represent a generic type that is NOT itself
199+
# Awaitable.) So if the T represents an Awaitable[V], then the type checker knows that
200+
# the returned function will return a Calc[Awaitable[V]].
201+
#
202+
# However, if the calc() function is specified to return a Callable that takes
203+
# CalcFunction[T] | CalcFunctionAsync[T], then if a CalcFunctionAsync is passed in, the
204+
# type check will not know that the returned Calc object is a Calc[Awaitable[V]]. It
205+
# will think that it's a [Calc[V]]. Then the type checker will think that the returned
206+
# Calc object is not async when it actually is.
207+
#
208+
# To work around this, we say that calc() returns a Callable that takes a
209+
# CalcFunction[T], instead of the union type. We're sort of tricking the type checker
210+
# twice: once here, and once when we return a Calc object (which has a synchronous
211+
# __call__ method) or CalcAsync object (which has an async __call__ method), and it
212+
# works out.
194213
def calc(
195214
*, session: Union[MISSING_TYPE, "Session", None] = MISSING
196-
) -> Callable[[Union[CalcFunction[T], CalcFunctionAsync[T]]], Calc[T]]:
215+
) -> Callable[[CalcFunction[T]], Calc[T]]:
197216
def create_calc(fn: Union[CalcFunction[T], CalcFunctionAsync[T]]) -> Calc[T]:
198217
if inspect.iscoroutinefunction(fn):
199218
fn = cast(CalcFunctionAsync[T], fn)

0 commit comments

Comments
 (0)