@@ -191,9 +191,28 @@ async def __call__(self) -> T:
191191 return await self .get_value ()
192192
193193
194+ # Note that the specified return type of calc() isn't exactly the same as the actual
195+ # returned object -- the former specifes a Callable that takes a CalcFunction[T], and
196+ # the latter is a Callable that takes CalcFunction[T] | CalcFunctionAsync[T]. Both are
197+ # technically correct, since the CalcFunction's T encompasses both "regular" types V as
198+ # well as Awatiable[V]. (We're using V to represent a generic type that is NOT itself
199+ # Awaitable.) So if the T represents an Awaitable[V], then the type checker knows that
200+ # the returned function will return a Calc[Awaitable[V]].
201+ #
202+ # However, if the calc() function is specified to return a Callable that takes
203+ # CalcFunction[T] | CalcFunctionAsync[T], then if a CalcFunctionAsync is passed in, the
204+ # type check will not know that the returned Calc object is a Calc[Awaitable[V]]. It
205+ # will think that it's a [Calc[V]]. Then the type checker will think that the returned
206+ # Calc object is not async when it actually is.
207+ #
208+ # To work around this, we say that calc() returns a Callable that takes a
209+ # CalcFunction[T], instead of the union type. We're sort of tricking the type checker
210+ # twice: once here, and once when we return a Calc object (which has a synchronous
211+ # __call__ method) or CalcAsync object (which has an async __call__ method), and it
212+ # works out.
194213def calc (
195214 * , session : Union [MISSING_TYPE , "Session" , None ] = MISSING
196- ) -> Callable [[Union [ CalcFunction [T ], CalcFunctionAsync [ T ] ]], Calc [T ]]:
215+ ) -> Callable [[CalcFunction [T ]], Calc [T ]]:
197216 def create_calc (fn : Union [CalcFunction [T ], CalcFunctionAsync [T ]]) -> Calc [T ]:
198217 if inspect .iscoroutinefunction (fn ):
199218 fn = cast (CalcFunctionAsync [T ], fn )
0 commit comments