diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py index 10e533135443..51ee6ff72bc1 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py @@ -14,6 +14,7 @@ from ._shared.user_credential import CommunicationTokenCredential from ._generated import AzureCommunicationChatService from ._generated.models import CreateChatThreadRequest +from ._generated.models import ChatThreadItem from ._models import ChatThreadProperties, CreateChatThreadResult from ._utils import ( # pylint: disable=unused-import _to_utc_datetime, @@ -156,18 +157,18 @@ def create_chat_thread( create_thread_request = CreateChatThreadRequest(topic=topic, participants=participants) - create_chat_thread_result = self._client.chat.create_chat_thread( + generated_result = self._client.chat.create_chat_thread( create_chat_thread_request=create_thread_request, repeatability_request_id=idempotency_token, **kwargs ) errors = None - if hasattr(create_chat_thread_result, "invalid_participants"): + if hasattr(generated_result, "invalid_participants"): errors = CommunicationErrorResponseConverter.convert( - participants=thread_participants or [], chat_errors=create_chat_thread_result.invalid_participants + participants=thread_participants or [], chat_errors=generated_result.invalid_participants ) chat_thread_properties = ChatThreadProperties._from_generated( # pylint:disable=protected-access - create_chat_thread_result.chat_thread + generated_result.chat_thread ) create_chat_thread_result = CreateChatThreadResult(chat_thread=chat_thread_properties, errors=errors) @@ -197,7 +198,7 @@ def list_chat_threads(self, **kwargs): results_per_page = kwargs.pop("results_per_page", None) start_time = kwargs.pop("start_time", None) - return self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs) + return self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs) # type: ignore[return-value] @distributed_trace def delete_chat_thread( diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py index 8d63a6d35b0a..6318478242f2 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py @@ -22,10 +22,11 @@ UpdateChatThreadRequest, ChatMessageType, SendChatMessageResult, + ChatError, ) from ._models import ChatParticipant, ChatMessage, ChatMessageReadReceipt, ChatThreadProperties -from ._communication_identifier_serializer import serialize_identifier +from ._communication_identifier_serializer import serialize_identifier, identifier_to_generated_model from ._utils import CommunicationErrorResponseConverter from ._version import SDK_MONIKER @@ -221,7 +222,7 @@ def list_read_receipts( results_per_page = kwargs.pop("results_per_page", None) skip = kwargs.pop("skip", None) - return self._client.chat_thread.list_chat_read_receipts( + return self._client.chat_thread.list_chat_read_receipts( # type: ignore[return-value] self._thread_id, max_page_size=results_per_page, skip=skip, @@ -377,14 +378,14 @@ def list_messages( results_per_page = kwargs.pop("results_per_page", None) start_time = kwargs.pop("start_time", None) - a = self._client.chat_thread.list_chat_messages( + a = self._client.chat_thread.list_chat_messages( # type: ignore[return-value] self._thread_id, max_page_size=results_per_page, start_time=start_time, cls=lambda objs: [ChatMessage._from_generated(x) for x in objs], # pylint:disable=protected-access **kwargs ) - return a + return a # type: ignore[return-value] @distributed_trace def update_message( @@ -484,7 +485,7 @@ def list_participants( results_per_page = kwargs.pop("results_per_page", None) skip = kwargs.pop("skip", None) - return self._client.chat_thread.list_chat_participants( + return self._client.chat_thread.list_chat_participants( # type: ignore[return-value] self._thread_id, max_page_size=results_per_page, skip=skip, @@ -564,7 +565,7 @@ def remove_participant( return self._client.chat_thread.remove_chat_participant( chat_thread_id=self._thread_id, - participant_communication_identifier=serialize_identifier(identifier), + participant_communication_identifier=identifier_to_generated_model(identifier), **kwargs ) diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py index 6b08da1d07ea..0f0d1bfd130e 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py @@ -28,7 +28,7 @@ def serialize_identifier(identifier): :rtype: ~azure.communication.chat._generated.models.CommunicationIdentifierModel """ try: - request_model = {"raw_id": identifier.raw_id} + request_model: Dict[str, Any] = {"raw_id": identifier.raw_id} if identifier.kind and identifier.kind != CommunicationIdentifierKind.UNKNOWN: request_model[identifier.kind] = dict(identifier.properties) @@ -52,7 +52,7 @@ def deserialize_identifier(identifier_model): raw_id = identifier_model.raw_id if identifier_model.communication_user: - return CommunicationUserIdentifier(raw_id, raw_id=raw_id) + return CommunicationUserIdentifier(raw_id or "", raw_id=raw_id) if identifier_model.phone_number: return PhoneNumberIdentifier(identifier_model.phone_number.value, raw_id=raw_id) if identifier_model.microsoft_teams_user: @@ -62,4 +62,19 @@ def deserialize_identifier(identifier_model): is_anonymous=identifier_model.microsoft_teams_user.is_anonymous, cloud=identifier_model.microsoft_teams_user.cloud, ) - return UnknownIdentifier(raw_id) + return UnknownIdentifier(raw_id or "") + + +def identifier_to_generated_model(identifier): + # type: (CommunicationIdentifier) -> CommunicationIdentifierModel + """Convert a CommunicationIdentifier to a CommunicationIdentifierModel + + :param identifier: Identifier object + :type identifier: CommunicationIdentifier + :return: CommunicationIdentifierModel + :rtype: ~azure.communication.chat._generated.models.CommunicationIdentifierModel + """ + from ._generated.models import CommunicationIdentifierModel + + serialized = serialize_identifier(identifier) + return CommunicationIdentifierModel(**serialized) # type: ignore[misc] diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py index c4dc8b52b098..23cf5104f3a0 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py @@ -4,6 +4,12 @@ # license information. # -------------------------------------------------------------------------- +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from ._models import ChatParticipant + from ._generated.models import ChatError + def _to_utc_datetime(value): return value.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -24,7 +30,7 @@ class CommunicationErrorResponseConverter(object): @classmethod def convert(cls, participants, chat_errors): - # type: (...) -> list[(ChatThreadParticipant, ChatError)] + # type: (...) -> List[Tuple[ChatParticipant, ChatError]] """ Util function to convert AddChatParticipantsResult. @@ -41,7 +47,7 @@ def convert(cls, participants, chat_errors): """ def create_dict(participants): - # type: (...) -> Dict(str, ChatThreadParticipant) + # type: (...) -> Dict[str, ChatParticipant] """ Create dictionary of id -> ChatParticipant @@ -61,6 +67,7 @@ def create_dict(participants): if chat_errors is not None: for chat_error in chat_errors: _thread_participant = _thread_participants_dict.get(chat_error.target) - failed_chat_thread_participants.append((_thread_participant, chat_error)) + if _thread_participant is not None: + failed_chat_thread_participants.append((_thread_participant, chat_error)) return failed_chat_thread_participants diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py index 8761951a38da..349022c5bb6b 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py @@ -141,18 +141,18 @@ async def create_chat_thread(self, topic: str, **kwargs) -> CreateChatThreadResu create_thread_request = CreateChatThreadRequest(topic=topic, participants=participants) - create_chat_thread_result = await self._client.chat.create_chat_thread( + generated_result = await self._client.chat.create_chat_thread( create_chat_thread_request=create_thread_request, repeatability_request_id=idempotency_token, **kwargs ) errors = None - if hasattr(create_chat_thread_result, "invalid_participants"): + if hasattr(generated_result, "invalid_participants"): errors = CommunicationErrorResponseConverter.convert( - participants=thread_participants or [], chat_errors=create_chat_thread_result.invalid_participants + participants=thread_participants or [], chat_errors=generated_result.invalid_participants ) chat_thread = ChatThreadProperties._from_generated( # pylint:disable=protected-access - create_chat_thread_result.chat_thread + generated_result.chat_thread ) create_chat_thread_result = CreateChatThreadResult(chat_thread=chat_thread, errors=errors) @@ -181,7 +181,7 @@ def list_chat_threads(self, **kwargs: Any) -> AsyncItemPaged[ChatThreadItem]: results_per_page = kwargs.pop("results_per_page", None) start_time = kwargs.pop("start_time", None) - return self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs) + return self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs) # type: ignore[return-value] @distributed_trace_async async def delete_chat_thread(self, thread_id: str, **kwargs) -> None: diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py index a4ac904ba28a..1eb55ef99408 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py @@ -29,7 +29,7 @@ ) from .._models import ChatParticipant, ChatMessage, ChatMessageReadReceipt, ChatThreadProperties from .._shared.models import CommunicationIdentifier -from .._communication_identifier_serializer import serialize_identifier +from .._communication_identifier_serializer import serialize_identifier, identifier_to_generated_model from .._utils import CommunicationErrorResponseConverter from .._version import SDK_MONIKER @@ -122,7 +122,7 @@ async def get_properties(self, **kwargs) -> ChatThreadProperties: return ChatThreadProperties._from_generated(chat_thread) # pylint:disable=protected-access @distributed_trace_async - async def update_topic(self, topic: str = None, **kwargs) -> None: + async def update_topic(self, topic: Optional[str] = None, **kwargs) -> None: """Updates a thread's properties. :param topic: Thread topic. If topic is not specified, the update will succeed but @@ -196,7 +196,7 @@ def list_read_receipts(self, **kwargs: Any) -> AsyncItemPaged[ChatMessageReadRec results_per_page = kwargs.pop("results_per_page", None) skip = kwargs.pop("skip", None) - return self._client.chat_thread.list_chat_read_receipts( + return self._client.chat_thread.list_chat_read_receipts( # type: ignore[return-value] self._thread_id, max_page_size=results_per_page, skip=skip, @@ -233,7 +233,7 @@ async def send_typing_notification(self, *, sender_display_name: Optional[str] = ) @distributed_trace_async - async def send_message(self, content: str, *, metadata: Dict[str, str] = None, **kwargs) -> SendChatMessageResult: + async def send_message(self, content: str, *, metadata: Optional[Dict[str, str]] = None, **kwargs) -> SendChatMessageResult: """Sends a message to a thread. :param content: Required. Chat message content. @@ -334,7 +334,7 @@ def list_messages(self, **kwargs: Any) -> AsyncItemPaged[ChatMessage]: results_per_page = kwargs.pop("results_per_page", None) start_time = kwargs.pop("start_time", None) - return self._client.chat_thread.list_chat_messages( + return self._client.chat_thread.list_chat_messages( # type: ignore[return-value] self._thread_id, max_page_size=results_per_page, start_time=start_time, @@ -344,7 +344,7 @@ def list_messages(self, **kwargs: Any) -> AsyncItemPaged[ChatMessage]: @distributed_trace_async async def update_message( - self, message_id: str, content: str = None, *, metadata: Dict[str, str] = None, **kwargs + self, message_id: str, content: Optional[str] = None, *, metadata: Optional[Dict[str, str]] = None, **kwargs ) -> None: """Updates a message. @@ -426,7 +426,7 @@ def list_participants(self, **kwargs: Any) -> AsyncItemPaged[ChatParticipant]: results_per_page = kwargs.pop("results_per_page", None) skip = kwargs.pop("skip", None) - return self._client.chat_thread.list_chat_participants( + return self._client.chat_thread.list_chat_participants( # type: ignore[return-value] self._thread_id, max_page_size=results_per_page, skip=skip, @@ -498,7 +498,7 @@ async def remove_participant(self, identifier: CommunicationIdentifier, **kwargs return await self._client.chat_thread.remove_chat_participant( chat_thread_id=self._thread_id, - participant_communication_identifier=serialize_identifier(identifier), + participant_communication_identifier=identifier_to_generated_model(identifier), **kwargs ) diff --git a/sdk/communication/azure-communication-chat/pyproject.toml b/sdk/communication/azure-communication-chat/pyproject.toml index 05ee5668ed6a..343877d67195 100644 --- a/sdk/communication/azure-communication-chat/pyproject.toml +++ b/sdk/communication/azure-communication-chat/pyproject.toml @@ -1,4 +1,4 @@ [tool.azure-sdk-build] -mypy = false +mypy = true pyright = false type_check_samples = false