diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py index 10e533135443..f9a9e0ec6d0b 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_client.py @@ -3,17 +3,18 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from uuid import uuid4 from urllib.parse import urlparse from azure.core.tracing.decorator import distributed_trace from azure.core.pipeline.policies import BearerTokenCredentialPolicy +from azure.core.paging import ItemPaged from ._chat_thread_client import ChatThreadClient from ._shared.user_credential import CommunicationTokenCredential from ._generated import AzureCommunicationChatService -from ._generated.models import CreateChatThreadRequest +from ._generated.models import CreateChatThreadRequest, ChatThreadItem from ._models import ChatThreadProperties, CreateChatThreadResult from ._utils import ( # pylint: disable=unused-import _to_utc_datetime, @@ -156,18 +157,18 @@ def create_chat_thread( create_thread_request = CreateChatThreadRequest(topic=topic, participants=participants) - create_chat_thread_result = self._client.chat.create_chat_thread( + create_chat_thread_result_generated = self._client.chat.create_chat_thread( create_chat_thread_request=create_thread_request, repeatability_request_id=idempotency_token, **kwargs ) errors = None - if hasattr(create_chat_thread_result, "invalid_participants"): + if hasattr(create_chat_thread_result_generated, "invalid_participants"): errors = CommunicationErrorResponseConverter.convert( - participants=thread_participants or [], chat_errors=create_chat_thread_result.invalid_participants + participants=thread_participants or [], chat_errors=create_chat_thread_result_generated.invalid_participants ) chat_thread_properties = ChatThreadProperties._from_generated( # pylint:disable=protected-access - create_chat_thread_result.chat_thread + create_chat_thread_result_generated.chat_thread ) create_chat_thread_result = CreateChatThreadResult(chat_thread=chat_thread_properties, errors=errors) @@ -197,7 +198,7 @@ def list_chat_threads(self, **kwargs): results_per_page = kwargs.pop("results_per_page", None) start_time = kwargs.pop("start_time", None) - return self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs) + return cast(ItemPaged[ChatThreadItem], self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs)) @distributed_trace def delete_chat_thread( diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py index 8d63a6d35b0a..1a8c27202eba 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_chat_thread_client.py @@ -3,12 +3,13 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast from urllib.parse import urlparse from azure.core.tracing.decorator import distributed_trace from azure.core.pipeline.policies import BearerTokenCredentialPolicy +from azure.core.paging import ItemPaged from ._shared.user_credential import CommunicationTokenCredential from ._shared.models import CommunicationIdentifier @@ -22,6 +23,8 @@ UpdateChatThreadRequest, ChatMessageType, SendChatMessageResult, + ChatError, + CommunicationIdentifierModel, ) from ._models import ChatParticipant, ChatMessage, ChatMessageReadReceipt, ChatThreadProperties @@ -221,7 +224,7 @@ def list_read_receipts( results_per_page = kwargs.pop("results_per_page", None) skip = kwargs.pop("skip", None) - return self._client.chat_thread.list_chat_read_receipts( + return cast(ItemPaged[ChatMessageReadReceipt], self._client.chat_thread.list_chat_read_receipts( self._thread_id, max_page_size=results_per_page, skip=skip, @@ -229,7 +232,7 @@ def list_read_receipts( ChatMessageReadReceipt._from_generated(x) for x in objs # pylint:disable=protected-access ], **kwargs - ) + )) @distributed_trace def send_typing_notification( @@ -377,14 +380,13 @@ def list_messages( results_per_page = kwargs.pop("results_per_page", None) start_time = kwargs.pop("start_time", None) - a = self._client.chat_thread.list_chat_messages( + return cast(ItemPaged[ChatMessage], self._client.chat_thread.list_chat_messages( self._thread_id, max_page_size=results_per_page, start_time=start_time, cls=lambda objs: [ChatMessage._from_generated(x) for x in objs], # pylint:disable=protected-access **kwargs - ) - return a + )) @distributed_trace def update_message( @@ -484,13 +486,13 @@ def list_participants( results_per_page = kwargs.pop("results_per_page", None) skip = kwargs.pop("skip", None) - return self._client.chat_thread.list_chat_participants( + return cast(ItemPaged[ChatParticipant], self._client.chat_thread.list_chat_participants( self._thread_id, max_page_size=results_per_page, skip=skip, cls=lambda objs: [ChatParticipant._from_generated(x) for x in objs], # pylint:disable=protected-access **kwargs - ) + )) @distributed_trace def add_participants( @@ -564,7 +566,7 @@ def remove_participant( return self._client.chat_thread.remove_chat_participant( chat_thread_id=self._thread_id, - participant_communication_identifier=serialize_identifier(identifier), + participant_communication_identifier=CommunicationIdentifierModel(**serialize_identifier(identifier)), **kwargs ) diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py index 6b08da1d07ea..3e31b1ec2a99 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_communication_identifier_serializer.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import Dict, Any, TYPE_CHECKING +from typing import Dict, Any, TYPE_CHECKING, cast from ._shared.models import ( CommunicationIdentifier, @@ -31,7 +31,7 @@ def serialize_identifier(identifier): request_model = {"raw_id": identifier.raw_id} if identifier.kind and identifier.kind != CommunicationIdentifierKind.UNKNOWN: - request_model[identifier.kind] = dict(identifier.properties) + request_model[str(identifier.kind)] = cast(Any, identifier.properties) return request_model except AttributeError: raise TypeError( # pylint: disable=raise-missing-from @@ -52,7 +52,7 @@ def deserialize_identifier(identifier_model): raw_id = identifier_model.raw_id if identifier_model.communication_user: - return CommunicationUserIdentifier(raw_id, raw_id=raw_id) + return CommunicationUserIdentifier(identifier_model.communication_user.id, raw_id=raw_id) if identifier_model.phone_number: return PhoneNumberIdentifier(identifier_model.phone_number.value, raw_id=raw_id) if identifier_model.microsoft_teams_user: @@ -62,4 +62,4 @@ def deserialize_identifier(identifier_model): is_anonymous=identifier_model.microsoft_teams_user.is_anonymous, cloud=identifier_model.microsoft_teams_user.cloud, ) - return UnknownIdentifier(raw_id) + return UnknownIdentifier(raw_id or "") diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/policy.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/policy.py index 1843d22e83a2..211f2ff0cfcf 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/policy.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/policy.py @@ -66,7 +66,7 @@ def _sign_request(self, request): # Need URL() to get a correct encoded key value, from "%3A" to ":", when transport is in type AioHttpTransport. # There's a similar scenario in azure-storage-blob and azure-appconfiguration, the check logic is from there. try: - from yarl import URL + from yarl import URL # type: ignore[import-not-found] from azure.core.pipeline.transport import ( # pylint:disable=non-abstract-transport-import AioHttpTransport, ) diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py index c4dc8b52b098..b59fdb246fea 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_utils.py @@ -4,6 +4,12 @@ # license information. # -------------------------------------------------------------------------- +from typing import List, Tuple, Dict, TYPE_CHECKING + +if TYPE_CHECKING: + from ._models import ChatParticipant + from ._generated.models import ChatError + def _to_utc_datetime(value): return value.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -24,7 +30,7 @@ class CommunicationErrorResponseConverter(object): @classmethod def convert(cls, participants, chat_errors): - # type: (...) -> list[(ChatThreadParticipant, ChatError)] + # type: (...) -> List[Tuple[ChatParticipant, ChatError]] """ Util function to convert AddChatParticipantsResult. @@ -41,12 +47,12 @@ def convert(cls, participants, chat_errors): """ def create_dict(participants): - # type: (...) -> Dict(str, ChatThreadParticipant) + # type: (...) -> Dict[str, ChatParticipant] """ Create dictionary of id -> ChatParticipant - :param list participants: list of ChatThreadParticipant - :return: Dictionary of id -> ChatThreadParticipant + :param list participants: list of ChatParticipant + :return: Dictionary of id -> ChatParticipant :rtype: dict """ result = {} @@ -61,6 +67,7 @@ def create_dict(participants): if chat_errors is not None: for chat_error in chat_errors: _thread_participant = _thread_participants_dict.get(chat_error.target) - failed_chat_thread_participants.append((_thread_participant, chat_error)) + if _thread_participant is not None: + failed_chat_thread_participants.append((_thread_participant, chat_error)) return failed_chat_thread_participants diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py index 8761951a38da..59bea6addd52 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse # pylint: disable=unused-import,ungrouped-imports -from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, cast from datetime import datetime from uuid import uuid4 @@ -141,18 +141,18 @@ async def create_chat_thread(self, topic: str, **kwargs) -> CreateChatThreadResu create_thread_request = CreateChatThreadRequest(topic=topic, participants=participants) - create_chat_thread_result = await self._client.chat.create_chat_thread( + create_chat_thread_result_generated = await self._client.chat.create_chat_thread( create_chat_thread_request=create_thread_request, repeatability_request_id=idempotency_token, **kwargs ) errors = None - if hasattr(create_chat_thread_result, "invalid_participants"): + if hasattr(create_chat_thread_result_generated, "invalid_participants"): errors = CommunicationErrorResponseConverter.convert( - participants=thread_participants or [], chat_errors=create_chat_thread_result.invalid_participants + participants=thread_participants or [], chat_errors=create_chat_thread_result_generated.invalid_participants ) chat_thread = ChatThreadProperties._from_generated( # pylint:disable=protected-access - create_chat_thread_result.chat_thread + create_chat_thread_result_generated.chat_thread ) create_chat_thread_result = CreateChatThreadResult(chat_thread=chat_thread, errors=errors) @@ -181,7 +181,7 @@ def list_chat_threads(self, **kwargs: Any) -> AsyncItemPaged[ChatThreadItem]: results_per_page = kwargs.pop("results_per_page", None) start_time = kwargs.pop("start_time", None) - return self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs) + return cast(AsyncItemPaged[ChatThreadItem], self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs)) @distributed_trace_async async def delete_chat_thread(self, thread_id: str, **kwargs) -> None: diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py index a4ac904ba28a..4269f7d2edfc 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse # pylint: disable=unused-import,ungrouped-imports -from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, Tuple +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, Tuple, cast from datetime import datetime from azure.core.tracing.decorator import distributed_trace @@ -26,6 +26,7 @@ SendChatMessageResult, ChatMessageType, ChatError, + CommunicationIdentifierModel, ) from .._models import ChatParticipant, ChatMessage, ChatMessageReadReceipt, ChatThreadProperties from .._shared.models import CommunicationIdentifier @@ -122,7 +123,7 @@ async def get_properties(self, **kwargs) -> ChatThreadProperties: return ChatThreadProperties._from_generated(chat_thread) # pylint:disable=protected-access @distributed_trace_async - async def update_topic(self, topic: str = None, **kwargs) -> None: + async def update_topic(self, topic: Optional[str] = None, **kwargs) -> None: """Updates a thread's properties. :param topic: Thread topic. If topic is not specified, the update will succeed but @@ -196,7 +197,7 @@ def list_read_receipts(self, **kwargs: Any) -> AsyncItemPaged[ChatMessageReadRec results_per_page = kwargs.pop("results_per_page", None) skip = kwargs.pop("skip", None) - return self._client.chat_thread.list_chat_read_receipts( + return cast(AsyncItemPaged[ChatMessageReadReceipt], self._client.chat_thread.list_chat_read_receipts( self._thread_id, max_page_size=results_per_page, skip=skip, @@ -204,7 +205,7 @@ def list_read_receipts(self, **kwargs: Any) -> AsyncItemPaged[ChatMessageReadRec ChatMessageReadReceipt._from_generated(x) for x in objs # pylint:disable=protected-access ], **kwargs - ) + )) @distributed_trace_async async def send_typing_notification(self, *, sender_display_name: Optional[str] = None, **kwargs) -> None: @@ -233,7 +234,7 @@ async def send_typing_notification(self, *, sender_display_name: Optional[str] = ) @distributed_trace_async - async def send_message(self, content: str, *, metadata: Dict[str, str] = None, **kwargs) -> SendChatMessageResult: + async def send_message(self, content: str, *, metadata: Optional[Dict[str, str]] = None, **kwargs) -> SendChatMessageResult: """Sends a message to a thread. :param content: Required. Chat message content. @@ -334,17 +335,17 @@ def list_messages(self, **kwargs: Any) -> AsyncItemPaged[ChatMessage]: results_per_page = kwargs.pop("results_per_page", None) start_time = kwargs.pop("start_time", None) - return self._client.chat_thread.list_chat_messages( + return cast(AsyncItemPaged[ChatMessage], self._client.chat_thread.list_chat_messages( self._thread_id, max_page_size=results_per_page, start_time=start_time, cls=lambda objs: [ChatMessage._from_generated(x) for x in objs], # pylint:disable=protected-access **kwargs - ) + )) @distributed_trace_async async def update_message( - self, message_id: str, content: str = None, *, metadata: Dict[str, str] = None, **kwargs + self, message_id: str, content: Optional[str] = None, *, metadata: Optional[Dict[str, str]] = None, **kwargs ) -> None: """Updates a message. @@ -426,13 +427,13 @@ def list_participants(self, **kwargs: Any) -> AsyncItemPaged[ChatParticipant]: results_per_page = kwargs.pop("results_per_page", None) skip = kwargs.pop("skip", None) - return self._client.chat_thread.list_chat_participants( + return cast(AsyncItemPaged[ChatParticipant], self._client.chat_thread.list_chat_participants( self._thread_id, max_page_size=results_per_page, skip=skip, cls=lambda objs: [ChatParticipant._from_generated(x) for x in objs], # pylint:disable=protected-access **kwargs - ) + )) @distributed_trace_async async def add_participants( @@ -498,7 +499,7 @@ async def remove_participant(self, identifier: CommunicationIdentifier, **kwargs return await self._client.chat_thread.remove_chat_participant( chat_thread_id=self._thread_id, - participant_communication_identifier=serialize_identifier(identifier), + participant_communication_identifier=CommunicationIdentifierModel(**serialize_identifier(identifier)), **kwargs ) diff --git a/sdk/communication/azure-communication-chat/pyproject.toml b/sdk/communication/azure-communication-chat/pyproject.toml index 05ee5668ed6a..343877d67195 100644 --- a/sdk/communication/azure-communication-chat/pyproject.toml +++ b/sdk/communication/azure-communication-chat/pyproject.toml @@ -1,4 +1,4 @@ [tool.azure-sdk-build] -mypy = false +mypy = true pyright = false type_check_samples = false