@@ -663,30 +663,77 @@ def _map_key_value_types(
663663
664664 return ktype , vtype
665665
666- def _callable_type (self , method : d .MethodDescriptorProto ) -> str :
666+ def _callable_type (self , method : d .MethodDescriptorProto , is_async : bool = False ) -> str :
667+ module = "grpc.aio" if is_async else "grpc"
667668 if method .client_streaming :
668669 if method .server_streaming :
669- return self ._import ("grpc" , "StreamStreamMultiCallable" )
670+ return self ._import (module , "StreamStreamMultiCallable" )
670671 else :
671- return self ._import ("grpc" , "StreamUnaryMultiCallable" )
672+ return self ._import (module , "StreamUnaryMultiCallable" )
672673 else :
673674 if method .server_streaming :
674- return self ._import ("grpc" , "UnaryStreamMultiCallable" )
675+ return self ._import (module , "UnaryStreamMultiCallable" )
675676 else :
676- return self ._import ("grpc" , "UnaryUnaryMultiCallable" )
677+ return self ._import (module , "UnaryUnaryMultiCallable" )
677678
678- def _input_type (self , method : d .MethodDescriptorProto , use_stream_iterator : bool = True ) -> str :
679+ def _input_type (self , method : d .MethodDescriptorProto ) -> str :
679680 result = self ._import_message (method .input_type )
680- if use_stream_iterator and method .client_streaming :
681- result = f"{ self ._import ('collections.abc' , 'Iterator' )} [{ result } ]"
682681 return result
683682
684- def _output_type (self , method : d .MethodDescriptorProto , use_stream_iterator : bool = True ) -> str :
683+ def _servicer_input_type (self , method : d .MethodDescriptorProto ) -> str :
684+ result = self ._import_message (method .input_type )
685+ if method .client_streaming :
686+ # See write_grpc_async_hacks().
687+ result = f"_MaybeAsyncIterator[{ result } ]"
688+ return result
689+
690+ def _output_type (self , method : d .MethodDescriptorProto ) -> str :
685691 result = self ._import_message (method .output_type )
686- if use_stream_iterator and method .server_streaming :
687- result = f"{ self ._import ('collections.abc' , 'Iterator' )} [{ result } ]"
688692 return result
689693
694+ def _servicer_output_type (self , method : d .MethodDescriptorProto ) -> str :
695+ result = self ._import_message (method .output_type )
696+ if method .server_streaming :
697+ # Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp].
698+ # So both can be used in the covariant function return position.
699+ iterator = f"{ self ._import ('typing' , 'Iterator' )} [{ result } ]"
700+ aiterator = f"{ self ._import ('typing' , 'AsyncIterator' )} [{ result } ]"
701+ result = f"{ self ._import ('typing' , 'Union' )} [{ iterator } , { aiterator } ]"
702+ else :
703+ # Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp].
704+ # So both can be used in the covariant function return position.
705+ # Awaitable[Resp] is equivalent to async def.
706+ awaitable = f"{ self ._import ('typing' , 'Awaitable' )} [{ result } ]"
707+ result = f"{ self ._import ('typing' , 'Union' )} [{ result } , { awaitable } ]"
708+ return result
709+
710+ def write_grpc_async_hacks (self ) -> None :
711+ wl = self ._write_line
712+ # _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req].
713+ # So both can be used in the contravariant function parameter position.
714+ wl ("_T = {}('_T')" , self ._import ("typing" , "TypeVar" ))
715+ wl ("" )
716+ wl (
717+ "class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}):" ,
718+ self ._import ("typing" , "AsyncIterator" ),
719+ self ._import ("typing" , "Iterator" ),
720+ self ._import ("abc" , "ABCMeta" ),
721+ )
722+ with self ._indent ():
723+ wl ("..." )
724+ wl ("" )
725+
726+ # _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext
727+ # So both can be used in the contravariant function parameter position.
728+ wl (
729+ "class _ServicerContext({}, {}): # type: ignore" ,
730+ self ._import ("grpc" , "ServicerContext" ),
731+ self ._import ("grpc.aio" , "ServicerContext" ),
732+ )
733+ with self ._indent ():
734+ wl ("..." )
735+ wl ("" )
736+
690737 def write_grpc_methods (self , service : d .ServiceDescriptorProto , scl_prefix : SourceCodeLocation ) -> None :
691738 wl = self ._write_line
692739 methods = [(i , m ) for i , m in enumerate (service .method ) if m .name not in PYTHON_RESERVED ]
@@ -701,20 +748,20 @@ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: Sour
701748 with self ._indent ():
702749 wl ("self," )
703750 input_name = "request_iterator" if method .client_streaming else "request"
704- input_type = self ._input_type (method )
751+ input_type = self ._servicer_input_type (method )
705752 wl (f"{ input_name } : { input_type } ," )
706- wl ("context: {}," , self . _import ( "grpc" , "ServicerContext" ) )
753+ wl ("context: _ServicerContext," )
707754 wl (
708755 ") -> {}:{}" ,
709- self ._output_type (method ),
756+ self ._servicer_output_type (method ),
710757 " ..." if not self ._has_comments (scl ) else "" ,
711758 )
712759 if self ._has_comments (scl ):
713760 with self ._indent ():
714761 if not self ._write_comments (scl ):
715762 wl ("..." )
716763
717- def write_grpc_stub_methods (self , service : d .ServiceDescriptorProto , scl_prefix : SourceCodeLocation ) -> None :
764+ def write_grpc_stub_methods (self , service : d .ServiceDescriptorProto , scl_prefix : SourceCodeLocation , is_async : bool = False ) -> None :
718765 wl = self ._write_line
719766 methods = [(i , m ) for i , m in enumerate (service .method ) if m .name not in PYTHON_RESERVED ]
720767 if not methods :
@@ -723,10 +770,10 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
723770 for i , method in methods :
724771 scl = scl_prefix + [d .ServiceDescriptorProto .METHOD_FIELD_NUMBER , i ]
725772
726- wl ("{}: {}[" , method .name , self ._callable_type (method ))
773+ wl ("{}: {}[" , method .name , self ._callable_type (method , is_async = is_async ))
727774 with self ._indent ():
728- wl ("{}," , self ._input_type (method , False ))
729- wl ("{}," , self ._output_type (method , False ))
775+ wl ("{}," , self ._input_type (method ))
776+ wl ("{}," , self ._output_type (method ))
730777 wl ("]" )
731778 self ._write_comments (scl )
732779
@@ -743,17 +790,34 @@ def write_grpc_services(
743790 scl = scl_prefix + [i ]
744791
745792 # The stub client
746- wl (f"class { service .name } Stub:" )
793+ wl (
794+ "class {}Stub:" ,
795+ service .name ,
796+ )
747797 with self ._indent ():
748798 if self ._write_comments (scl ):
749799 wl ("" )
800+ # To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
801+ channel = f"{ self ._import ('typing' , 'Union' )} [{ self ._import ('grpc' , 'Channel' )} , { self ._import ('grpc.aio' , 'Channel' )} ]"
750802 wl (
751803 "def __init__(self, channel: {}) -> None: ..." ,
752- self . _import ( "grpc" , "Channel" ),
804+ channel
753805 )
754806 self .write_grpc_stub_methods (service , scl )
755807 wl ("" )
756808
809+ # The (fake) async stub client
810+ wl (
811+ "class {}AsyncStub:" ,
812+ service .name ,
813+ )
814+ with self ._indent ():
815+ if self ._write_comments (scl ):
816+ wl ("" )
817+ # No __init__ since this isn't a real class (yet), and requires manual casting to work.
818+ self .write_grpc_stub_methods (service , scl , is_async = True )
819+ wl ("" )
820+
757821 # The service definition interface
758822 wl (
759823 "class {}Servicer(metaclass={}):" ,
@@ -765,11 +829,13 @@ def write_grpc_services(
765829 wl ("" )
766830 self .write_grpc_methods (service , scl )
767831 wl ("" )
832+ server = self ._import ('grpc' , 'Server' )
833+ aserver = self ._import ('grpc.aio' , 'Server' )
768834 wl (
769835 "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ..." ,
770836 service .name ,
771837 service .name ,
772- self ._import ("grpc" , "Server" ) ,
838+ f" { self ._import ('typing' , 'Union' ) } [ { server } , { aserver } ]" ,
773839 )
774840 wl ("" )
775841
@@ -960,6 +1026,7 @@ def generate_mypy_grpc_stubs(
9601026 relax_strict_optional_primitives ,
9611027 grpc = True ,
9621028 )
1029+ pkg_writer .write_grpc_async_hacks ()
9631030 pkg_writer .write_grpc_services (fd .service , [d .FileDescriptorProto .SERVICE_FIELD_NUMBER ])
9641031
9651032 assert name == fd .name
0 commit comments