|
21 | 21 | TypeVar,
|
22 | 22 | Union,
|
23 | 23 | get_type_hints,
|
| 24 | + TYPE_CHECKING, |
24 | 25 | )
|
25 | 26 |
|
26 | 27 | import grpclib.client
|
|
29 | 30 |
|
30 | 31 | from .casing import safe_snake_case
|
31 | 32 |
|
| 33 | +if TYPE_CHECKING: |
| 34 | + from grpclib._protocols import IProtoMessage |
| 35 | + |
32 | 36 | # Proto 3 data types
|
33 | 37 | TYPE_ENUM = "enum"
|
34 | 38 | TYPE_BOOL = "bool"
|
@@ -420,6 +424,7 @@ class Message(ABC):
|
420 | 424 | register the message fields which get used by the serializers and parsers
|
421 | 425 | to go between Python, binary and JSON protobuf message representations.
|
422 | 426 | """
|
| 427 | + |
423 | 428 | _serialized_on_wire: bool
|
424 | 429 | _unknown_fields: bytes
|
425 | 430 | _group_map: Dict[str, dict]
|
@@ -705,7 +710,7 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
|
705 | 710 | for field in dataclasses.fields(self):
|
706 | 711 | meta = FieldMetadata.get(field)
|
707 | 712 | v = getattr(self, field.name)
|
708 |
| - cased_name = casing(field.name).rstrip("_") # type: ignore |
| 713 | + cased_name = casing(field.name).rstrip("_") # type: ignore |
709 | 714 | if meta.proto_type == "message":
|
710 | 715 | if isinstance(v, datetime):
|
711 | 716 | if v != DATETIME_ZERO:
|
@@ -741,7 +746,7 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
|
741 | 746 | else:
|
742 | 747 | output[cased_name] = b64encode(v).decode("utf8")
|
743 | 748 | elif meta.proto_type == TYPE_ENUM:
|
744 |
| - enum_values = list(self._cls_for(field)) # type: ignore |
| 749 | + enum_values = list(self._cls_for(field)) # type: ignore |
745 | 750 | if isinstance(v, list):
|
746 | 751 | output[cased_name] = [enum_values[e].name for e in v]
|
747 | 752 | else:
|
@@ -902,6 +907,7 @@ class _WrappedMessage(Message):
|
902 | 907 | Google protobuf wrapper types base class. JSON representation is just the
|
903 | 908 | value itself.
|
904 | 909 | """
|
| 910 | + |
905 | 911 | value: Any
|
906 | 912 |
|
907 | 913 | def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
|
@@ -982,23 +988,23 @@ def __init__(self, channel: grpclib.client.Channel) -> None:
|
982 | 988 | self.channel = channel
|
983 | 989 |
|
984 | 990 | async def _unary_unary(
|
985 |
| - self, route: str, request_type: Type, response_type: Type[T], request: Any |
| 991 | + self, route: str, request: "IProtoMessage", response_type: Type[T] |
986 | 992 | ) -> T:
|
987 | 993 | """Make a unary request and return the response."""
|
988 | 994 | async with self.channel.request(
|
989 |
| - route, grpclib.const.Cardinality.UNARY_UNARY, request_type, response_type |
| 995 | + route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type |
990 | 996 | ) as stream:
|
991 | 997 | await stream.send_message(request, end=True)
|
992 | 998 | response = await stream.recv_message()
|
993 | 999 | assert response is not None
|
994 | 1000 | return response
|
995 | 1001 |
|
996 | 1002 | async def _unary_stream(
|
997 |
| - self, route: str, request_type: Type, response_type: Type[T], request: Any |
| 1003 | + self, route: str, request: "IProtoMessage", response_type: Type[T] |
998 | 1004 | ) -> AsyncGenerator[T, None]:
|
999 | 1005 | """Make a unary request and return the stream response iterator."""
|
1000 | 1006 | async with self.channel.request(
|
1001 |
| - route, grpclib.const.Cardinality.UNARY_STREAM, request_type, response_type |
| 1007 | + route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type |
1002 | 1008 | ) as stream:
|
1003 | 1009 | await stream.send_message(request, end=True)
|
1004 | 1010 | async for message in stream:
|
|
0 commit comments