Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4
from urllib.parse import urlparse

from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.paging import ItemPaged

from ._chat_thread_client import ChatThreadClient
from ._shared.user_credential import CommunicationTokenCredential
from ._generated import AzureCommunicationChatService
from ._generated.models import CreateChatThreadRequest
from ._generated.models import CreateChatThreadRequest, ChatThreadItem
from ._models import ChatThreadProperties, CreateChatThreadResult
from ._utils import ( # pylint: disable=unused-import
_to_utc_datetime,
Expand Down Expand Up @@ -156,18 +157,18 @@ def create_chat_thread(

create_thread_request = CreateChatThreadRequest(topic=topic, participants=participants)

create_chat_thread_result = self._client.chat.create_chat_thread(
create_chat_thread_result_generated = self._client.chat.create_chat_thread(
create_chat_thread_request=create_thread_request, repeatability_request_id=idempotency_token, **kwargs
)

errors = None
if hasattr(create_chat_thread_result, "invalid_participants"):
if hasattr(create_chat_thread_result_generated, "invalid_participants"):
errors = CommunicationErrorResponseConverter.convert(
participants=thread_participants or [], chat_errors=create_chat_thread_result.invalid_participants
participants=thread_participants or [], chat_errors=create_chat_thread_result_generated.invalid_participants
)

chat_thread_properties = ChatThreadProperties._from_generated( # pylint:disable=protected-access
create_chat_thread_result.chat_thread
create_chat_thread_result_generated.chat_thread
)

create_chat_thread_result = CreateChatThreadResult(chat_thread=chat_thread_properties, errors=errors)
Expand Down Expand Up @@ -197,7 +198,7 @@ def list_chat_threads(self, **kwargs):
results_per_page = kwargs.pop("results_per_page", None)
start_time = kwargs.pop("start_time", None)

return self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs)
return cast(ItemPaged[ChatThreadItem], self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs))

@distributed_trace
def delete_chat_thread(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast

from urllib.parse import urlparse

from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.paging import ItemPaged

from ._shared.user_credential import CommunicationTokenCredential
from ._shared.models import CommunicationIdentifier
Expand All @@ -22,6 +23,8 @@
UpdateChatThreadRequest,
ChatMessageType,
SendChatMessageResult,
ChatError,
CommunicationIdentifierModel,
)
from ._models import ChatParticipant, ChatMessage, ChatMessageReadReceipt, ChatThreadProperties

Expand Down Expand Up @@ -221,15 +224,15 @@ def list_read_receipts(
results_per_page = kwargs.pop("results_per_page", None)
skip = kwargs.pop("skip", None)

return self._client.chat_thread.list_chat_read_receipts(
return cast(ItemPaged[ChatMessageReadReceipt], self._client.chat_thread.list_chat_read_receipts(
self._thread_id,
max_page_size=results_per_page,
skip=skip,
cls=lambda objs: [
ChatMessageReadReceipt._from_generated(x) for x in objs # pylint:disable=protected-access
],
**kwargs
)
))

@distributed_trace
def send_typing_notification(
Expand Down Expand Up @@ -377,14 +380,13 @@ def list_messages(
results_per_page = kwargs.pop("results_per_page", None)
start_time = kwargs.pop("start_time", None)

a = self._client.chat_thread.list_chat_messages(
return cast(ItemPaged[ChatMessage], self._client.chat_thread.list_chat_messages(
self._thread_id,
max_page_size=results_per_page,
start_time=start_time,
cls=lambda objs: [ChatMessage._from_generated(x) for x in objs], # pylint:disable=protected-access
**kwargs
)
return a
))

@distributed_trace
def update_message(
Expand Down Expand Up @@ -484,13 +486,13 @@ def list_participants(
results_per_page = kwargs.pop("results_per_page", None)
skip = kwargs.pop("skip", None)

return self._client.chat_thread.list_chat_participants(
return cast(ItemPaged[ChatParticipant], self._client.chat_thread.list_chat_participants(
self._thread_id,
max_page_size=results_per_page,
skip=skip,
cls=lambda objs: [ChatParticipant._from_generated(x) for x in objs], # pylint:disable=protected-access
**kwargs
)
))

@distributed_trace
def add_participants(
Expand Down Expand Up @@ -564,7 +566,7 @@ def remove_participant(

return self._client.chat_thread.remove_chat_participant(
chat_thread_id=self._thread_id,
participant_communication_identifier=serialize_identifier(identifier),
participant_communication_identifier=CommunicationIdentifierModel(**serialize_identifier(identifier)),
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import Dict, Any, TYPE_CHECKING
from typing import Dict, Any, TYPE_CHECKING, cast

from ._shared.models import (
CommunicationIdentifier,
Expand Down Expand Up @@ -31,7 +31,7 @@ def serialize_identifier(identifier):
request_model = {"raw_id": identifier.raw_id}

if identifier.kind and identifier.kind != CommunicationIdentifierKind.UNKNOWN:
request_model[identifier.kind] = dict(identifier.properties)
request_model[str(identifier.kind)] = cast(Any, identifier.properties)
return request_model
except AttributeError:
raise TypeError( # pylint: disable=raise-missing-from
Expand All @@ -52,7 +52,7 @@ def deserialize_identifier(identifier_model):
raw_id = identifier_model.raw_id

if identifier_model.communication_user:
return CommunicationUserIdentifier(raw_id, raw_id=raw_id)
return CommunicationUserIdentifier(identifier_model.communication_user.id, raw_id=raw_id)
if identifier_model.phone_number:
return PhoneNumberIdentifier(identifier_model.phone_number.value, raw_id=raw_id)
if identifier_model.microsoft_teams_user:
Expand All @@ -62,4 +62,4 @@ def deserialize_identifier(identifier_model):
is_anonymous=identifier_model.microsoft_teams_user.is_anonymous,
cloud=identifier_model.microsoft_teams_user.cloud,
)
return UnknownIdentifier(raw_id)
return UnknownIdentifier(raw_id or "")
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _sign_request(self, request):
# Need URL() to get a correct encoded key value, from "%3A" to ":", when transport is in type AioHttpTransport.
# There's a similar scenario in azure-storage-blob and azure-appconfiguration, the check logic is from there.
try:
from yarl import URL
from yarl import URL # type: ignore[import-not-found]
from azure.core.pipeline.transport import ( # pylint:disable=non-abstract-transport-import
AioHttpTransport,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
# license information.
# --------------------------------------------------------------------------

from typing import List, Tuple, Dict, TYPE_CHECKING

if TYPE_CHECKING:
from ._models import ChatParticipant
from ._generated.models import ChatError


def _to_utc_datetime(value):
return value.strftime("%Y-%m-%dT%H:%M:%SZ")
Expand All @@ -24,7 +30,7 @@ class CommunicationErrorResponseConverter(object):

@classmethod
def convert(cls, participants, chat_errors):
# type: (...) -> list[(ChatThreadParticipant, ChatError)]
# type: (...) -> List[Tuple[ChatParticipant, ChatError]]
"""
Util function to convert AddChatParticipantsResult.

Expand All @@ -41,12 +47,12 @@ def convert(cls, participants, chat_errors):
"""

def create_dict(participants):
# type: (...) -> Dict(str, ChatThreadParticipant)
# type: (...) -> Dict[str, ChatParticipant]
"""
Create dictionary of id -> ChatParticipant

:param list participants: list of ChatThreadParticipant
:return: Dictionary of id -> ChatThreadParticipant
:param list participants: list of ChatParticipant
:return: Dictionary of id -> ChatParticipant
:rtype: dict
"""
result = {}
Expand All @@ -61,6 +67,7 @@ def create_dict(participants):
if chat_errors is not None:
for chat_error in chat_errors:
_thread_participant = _thread_participants_dict.get(chat_error.target)
failed_chat_thread_participants.append((_thread_participant, chat_error))
if _thread_participant is not None:
failed_chat_thread_participants.append((_thread_participant, chat_error))

return failed_chat_thread_participants
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from urllib.parse import urlparse

# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, cast
from datetime import datetime
from uuid import uuid4

Expand Down Expand Up @@ -141,18 +141,18 @@ async def create_chat_thread(self, topic: str, **kwargs) -> CreateChatThreadResu

create_thread_request = CreateChatThreadRequest(topic=topic, participants=participants)

create_chat_thread_result = await self._client.chat.create_chat_thread(
create_chat_thread_result_generated = await self._client.chat.create_chat_thread(
create_chat_thread_request=create_thread_request, repeatability_request_id=idempotency_token, **kwargs
)

errors = None
if hasattr(create_chat_thread_result, "invalid_participants"):
if hasattr(create_chat_thread_result_generated, "invalid_participants"):
errors = CommunicationErrorResponseConverter.convert(
participants=thread_participants or [], chat_errors=create_chat_thread_result.invalid_participants
participants=thread_participants or [], chat_errors=create_chat_thread_result_generated.invalid_participants
)

chat_thread = ChatThreadProperties._from_generated( # pylint:disable=protected-access
create_chat_thread_result.chat_thread
create_chat_thread_result_generated.chat_thread
)

create_chat_thread_result = CreateChatThreadResult(chat_thread=chat_thread, errors=errors)
Expand Down Expand Up @@ -181,7 +181,7 @@ def list_chat_threads(self, **kwargs: Any) -> AsyncItemPaged[ChatThreadItem]:
results_per_page = kwargs.pop("results_per_page", None)
start_time = kwargs.pop("start_time", None)

return self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs)
return cast(AsyncItemPaged[ChatThreadItem], self._client.chat.list_chat_threads(max_page_size=results_per_page, start_time=start_time, **kwargs))

@distributed_trace_async
async def delete_chat_thread(self, thread_id: str, **kwargs) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from urllib.parse import urlparse

# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, Tuple
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, Tuple, cast
from datetime import datetime

from azure.core.tracing.decorator import distributed_trace
Expand All @@ -26,6 +26,7 @@
SendChatMessageResult,
ChatMessageType,
ChatError,
CommunicationIdentifierModel,
)
from .._models import ChatParticipant, ChatMessage, ChatMessageReadReceipt, ChatThreadProperties
from .._shared.models import CommunicationIdentifier
Expand Down Expand Up @@ -122,7 +123,7 @@ async def get_properties(self, **kwargs) -> ChatThreadProperties:
return ChatThreadProperties._from_generated(chat_thread) # pylint:disable=protected-access

@distributed_trace_async
async def update_topic(self, topic: str = None, **kwargs) -> None:
async def update_topic(self, topic: Optional[str] = None, **kwargs) -> None:
"""Updates a thread's properties.

:param topic: Thread topic. If topic is not specified, the update will succeed but
Expand Down Expand Up @@ -196,15 +197,15 @@ def list_read_receipts(self, **kwargs: Any) -> AsyncItemPaged[ChatMessageReadRec
results_per_page = kwargs.pop("results_per_page", None)
skip = kwargs.pop("skip", None)

return self._client.chat_thread.list_chat_read_receipts(
return cast(AsyncItemPaged[ChatMessageReadReceipt], self._client.chat_thread.list_chat_read_receipts(
self._thread_id,
max_page_size=results_per_page,
skip=skip,
cls=lambda objs: [
ChatMessageReadReceipt._from_generated(x) for x in objs # pylint:disable=protected-access
],
**kwargs
)
))

@distributed_trace_async
async def send_typing_notification(self, *, sender_display_name: Optional[str] = None, **kwargs) -> None:
Expand Down Expand Up @@ -233,7 +234,7 @@ async def send_typing_notification(self, *, sender_display_name: Optional[str] =
)

@distributed_trace_async
async def send_message(self, content: str, *, metadata: Dict[str, str] = None, **kwargs) -> SendChatMessageResult:
async def send_message(self, content: str, *, metadata: Optional[Dict[str, str]] = None, **kwargs) -> SendChatMessageResult:
"""Sends a message to a thread.

:param content: Required. Chat message content.
Expand Down Expand Up @@ -334,17 +335,17 @@ def list_messages(self, **kwargs: Any) -> AsyncItemPaged[ChatMessage]:
results_per_page = kwargs.pop("results_per_page", None)
start_time = kwargs.pop("start_time", None)

return self._client.chat_thread.list_chat_messages(
return cast(AsyncItemPaged[ChatMessage], self._client.chat_thread.list_chat_messages(
self._thread_id,
max_page_size=results_per_page,
start_time=start_time,
cls=lambda objs: [ChatMessage._from_generated(x) for x in objs], # pylint:disable=protected-access
**kwargs
)
))

@distributed_trace_async
async def update_message(
self, message_id: str, content: str = None, *, metadata: Dict[str, str] = None, **kwargs
self, message_id: str, content: Optional[str] = None, *, metadata: Optional[Dict[str, str]] = None, **kwargs
) -> None:
"""Updates a message.

Expand Down Expand Up @@ -426,13 +427,13 @@ def list_participants(self, **kwargs: Any) -> AsyncItemPaged[ChatParticipant]:
results_per_page = kwargs.pop("results_per_page", None)
skip = kwargs.pop("skip", None)

return self._client.chat_thread.list_chat_participants(
return cast(AsyncItemPaged[ChatParticipant], self._client.chat_thread.list_chat_participants(
self._thread_id,
max_page_size=results_per_page,
skip=skip,
cls=lambda objs: [ChatParticipant._from_generated(x) for x in objs], # pylint:disable=protected-access
**kwargs
)
))

@distributed_trace_async
async def add_participants(
Expand Down Expand Up @@ -498,7 +499,7 @@ async def remove_participant(self, identifier: CommunicationIdentifier, **kwargs

return await self._client.chat_thread.remove_chat_participant(
chat_thread_id=self._thread_id,
participant_communication_identifier=serialize_identifier(identifier),
participant_communication_identifier=CommunicationIdentifierModel(**serialize_identifier(identifier)),
**kwargs
)

Expand Down
2 changes: 1 addition & 1 deletion sdk/communication/azure-communication-chat/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[tool.azure-sdk-build]
mypy = false
mypy = true
pyright = false
type_check_samples = false