diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py index 10e533135443..bdae61ef5bfc 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py @@ -9,11 +9,12 @@ from azure.core.tracing.decorator import distributed_trace from azure.core.pipeline.policies import BearerTokenCredentialPolicy +from azure.core.paging import ItemPaged from ._chat_thread_client import ChatThreadClient from ._shared.user_credential import CommunicationTokenCredential from ._generated import AzureCommunicationChatService -from ._generated.models import CreateChatThreadRequest +from ._generated.models import CreateChatThreadRequest, ChatThreadItem from ._models import ChatThreadProperties, CreateChatThreadResult from ._utils import ( # pylint: disable=unused-import _to_utc_datetime, @@ -25,7 +26,6 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports from datetime import datetime - from azure.core.paging import ItemPaged class ChatClient(object): # pylint: disable=client-accepts-api-version-keyword @@ -170,13 +170,12 @@ def create_chat_thread( create_chat_thread_result.chat_thread ) - create_chat_thread_result = CreateChatThreadResult(chat_thread=chat_thread_properties, errors=errors) + result = CreateChatThreadResult(chat_thread=chat_thread_properties, errors=errors) - return create_chat_thread_result + return result @distributed_trace - def list_chat_threads(self, **kwargs): - # type: (...) -> ItemPaged[ChatThreadItem] + def list_chat_threads(self, **kwargs: Any) -> ItemPaged[ChatThreadItem]: """Gets the list of chat threads of a user. :keyword int results_per_page: The maximum number of chat threads returned per page. diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py index 8d63a6d35b0a..5dbf5551ee7c 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py @@ -9,6 +9,7 @@ from azure.core.tracing.decorator import distributed_trace from azure.core.pipeline.policies import BearerTokenCredentialPolicy +from azure.core.paging import ItemPaged from ._shared.user_credential import CommunicationTokenCredential from ._shared.models import CommunicationIdentifier @@ -22,6 +23,8 @@ UpdateChatThreadRequest, ChatMessageType, SendChatMessageResult, + ChatError, + CommunicationIdentifierModel, ) from ._models import ChatParticipant, ChatMessage, ChatMessageReadReceipt, ChatThreadProperties @@ -198,9 +201,8 @@ def send_read_receipt( @distributed_trace def list_read_receipts( - self, **kwargs # type: Any - ): - # type: (...) -> ItemPaged[ChatMessageReadReceipt] + self, **kwargs: Any + ) -> ItemPaged[ChatMessageReadReceipt]: """Gets read receipts for a thread. :keyword int results_per_page: The maximum number of chat message read receipts to be returned per page. @@ -353,9 +355,8 @@ def get_message( @distributed_trace def list_messages( - self, **kwargs # type: Any - ): - # type: (...) -> ItemPaged[ChatMessage] + self, **kwargs: Any + ) -> ItemPaged[ChatMessage]: """Gets a list of messages from a thread. :keyword int results_per_page: The maximum number of messages to be returned per page. @@ -460,9 +461,8 @@ def delete_message( @distributed_trace def list_participants( - self, **kwargs # type: Any - ): - # type: (...) -> ItemPaged[ChatParticipant] + self, **kwargs: Any + ) -> ItemPaged[ChatParticipant]: """Gets the participants of a thread. :keyword int results_per_page: The maximum number of participants to be returned per page. @@ -495,10 +495,9 @@ def list_participants( @distributed_trace def add_participants( self, - thread_participants, # type: List[ChatParticipant] - **kwargs # type: Any - ): - # type: (...) -> List[Tuple[ChatParticipant, ChatError]] + thread_participants: List[ChatParticipant], + **kwargs: Any + ) -> List[Tuple[Optional[ChatParticipant], ChatError]]: """Adds thread participants to a thread. If participants already exist, no change occurs. If all participants are added successfully, then an empty list is returned; @@ -562,9 +561,10 @@ def remove_participant( if not identifier: raise ValueError("identifier cannot be None.") + participant_model = CommunicationIdentifierModel(**serialize_identifier(identifier)) return self._client.chat_thread.remove_chat_participant( chat_thread_id=self._thread_id, - participant_communication_identifier=serialize_identifier(identifier), + participant_communication_identifier=participant_model, **kwargs ) diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py index 6b08da1d07ea..0e8b9a0869a8 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import Dict, Any, TYPE_CHECKING +from typing import Dict, Any, TYPE_CHECKING, cast from ._shared.models import ( CommunicationIdentifier, @@ -28,10 +28,11 @@ def serialize_identifier(identifier): :rtype: ~azure.communication.chat._generated.models.CommunicationIdentifierModel """ try: - request_model = {"raw_id": identifier.raw_id} + request_model: Dict[str, Any] = {"raw_id": identifier.raw_id} if identifier.kind and identifier.kind != CommunicationIdentifierKind.UNKNOWN: - request_model[identifier.kind] = dict(identifier.properties) + kind_str = cast(str, identifier.kind) + request_model[kind_str] = dict(identifier.properties) return request_model except AttributeError: raise TypeError( # pylint: disable=raise-missing-from @@ -52,7 +53,7 @@ def deserialize_identifier(identifier_model): raw_id = identifier_model.raw_id if identifier_model.communication_user: - return CommunicationUserIdentifier(raw_id, raw_id=raw_id) + return CommunicationUserIdentifier(identifier_model.communication_user.id, raw_id=raw_id) if identifier_model.phone_number: return PhoneNumberIdentifier(identifier_model.phone_number.value, raw_id=raw_id) if identifier_model.microsoft_teams_user: @@ -62,4 +63,4 @@ def deserialize_identifier(identifier_model): is_anonymous=identifier_model.microsoft_teams_user.is_anonymous, cloud=identifier_model.microsoft_teams_user.cloud, ) - return UnknownIdentifier(raw_id) + return UnknownIdentifier(raw_id or "") diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py index c4dc8b52b098..a63499d41ed5 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py @@ -4,6 +4,10 @@ # license information. # -------------------------------------------------------------------------- +from typing import Any, Dict, List, Optional, Tuple +from ._models import ChatParticipant +from ._generated.models import ChatError + def _to_utc_datetime(value): return value.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -23,8 +27,7 @@ class CommunicationErrorResponseConverter(object): """ @classmethod - def convert(cls, participants, chat_errors): - # type: (...) -> list[(ChatThreadParticipant, ChatError)] + def convert(cls, participants: List[ChatParticipant], chat_errors: Optional[List[ChatError]]) -> List[Tuple[Optional[ChatParticipant], ChatError]]: """ Util function to convert AddChatParticipantsResult. @@ -40,13 +43,12 @@ def convert(cls, participants, chat_errors): :rtype: list[(~azure.communication.chat.ChatParticipant, ~azure.communication.chat.ChatError)] """ - def create_dict(participants): - # type: (...) -> Dict(str, ChatThreadParticipant) + def create_dict(participants: List[ChatParticipant]) -> Dict[str, ChatParticipant]: """ Create dictionary of id -> ChatParticipant - :param list participants: list of ChatThreadParticipant - :return: Dictionary of id -> ChatThreadParticipant + :param participants: list of ChatParticipant + :return: Dictionary of id -> ChatParticipant :rtype: dict """ result = {} @@ -56,11 +58,13 @@ def create_dict(participants): _thread_participants_dict = create_dict(participants=participants) - failed_chat_thread_participants = [] + failed_chat_thread_participants: List[Tuple[Optional[ChatParticipant], ChatError]] = [] if chat_errors is not None: for chat_error in chat_errors: - _thread_participant = _thread_participants_dict.get(chat_error.target) - failed_chat_thread_participants.append((_thread_participant, chat_error)) + target = chat_error.target + if target is not None: + _thread_participant = _thread_participants_dict.get(target) + failed_chat_thread_participants.append((_thread_participant, chat_error)) return failed_chat_thread_participants diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py index a4ac904ba28a..acd002559d47 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py @@ -26,6 +26,7 @@ SendChatMessageResult, ChatMessageType, ChatError, + CommunicationIdentifierModel, ) from .._models import ChatParticipant, ChatMessage, ChatMessageReadReceipt, ChatThreadProperties from .._shared.models import CommunicationIdentifier @@ -122,7 +123,7 @@ async def get_properties(self, **kwargs) -> ChatThreadProperties: return ChatThreadProperties._from_generated(chat_thread) # pylint:disable=protected-access @distributed_trace_async - async def update_topic(self, topic: str = None, **kwargs) -> None: + async def update_topic(self, topic: Optional[str] = None, **kwargs: Any) -> None: """Updates a thread's properties. :param topic: Thread topic. If topic is not specified, the update will succeed but @@ -233,7 +234,7 @@ async def send_typing_notification(self, *, sender_display_name: Optional[str] = ) @distributed_trace_async - async def send_message(self, content: str, *, metadata: Dict[str, str] = None, **kwargs) -> SendChatMessageResult: + async def send_message(self, content: str, *, metadata: Optional[Dict[str, str]] = None, **kwargs: Any) -> SendChatMessageResult: """Sends a message to a thread. :param content: Required. Chat message content. @@ -344,7 +345,7 @@ def list_messages(self, **kwargs: Any) -> AsyncItemPaged[ChatMessage]: @distributed_trace_async async def update_message( - self, message_id: str, content: str = None, *, metadata: Dict[str, str] = None, **kwargs + self, message_id: str, content: Optional[str] = None, *, metadata: Optional[Dict[str, str]] = None, **kwargs: Any ) -> None: """Updates a message. diff --git a/sdk/communication/azure-communication-chat/pyproject.toml b/sdk/communication/azure-communication-chat/pyproject.toml index 05ee5668ed6a..343877d67195 100644 --- a/sdk/communication/azure-communication-chat/pyproject.toml +++ b/sdk/communication/azure-communication-chat/pyproject.toml @@ -1,4 +1,4 @@ [tool.azure-sdk-build] -mypy = false +mypy = true pyright = false type_check_samples = false