Skip to content

Commit 6c67965

Browse files
[EH] improve typing (#33338)
* remove py2 typehints * pyright * adding typing * more * changes * changes * fix * fix * pylint * remove * ah * changes * update * pylint * client base changes * typing * adding more typing * remove ignore pyamqp * pull swathi changes * fix sample errors * pausing here * fixes * checkpointstore typing * pylint * connstate * undo samples typing * samples * changes * update * fwd ref * settler * try this * change where protocol is called from? * revert amqp typing * swathi changes * remove * emove * initial pass * nit * update exceptions * add back msg_backcompat for errors * revert changes * pylint * align with swathis pyamqp * mising default * pyamqp error * update error * update * pylint * Update sdk/eventhub/azure-eventhub-checkpointstoreblob-aio/azure/eventhub/extensions/checkpointstoreblobaio/_blobstoragecsaio.py Co-authored-by: Kashif Khan <[email protected]> * pr comments * update * kwargs.pop * fixing broken tests * remove type ignore --------- Co-authored-by: Kashif Khan <[email protected]> Co-authored-by: Kashif Khan <[email protected]>
1 parent a305fa9 commit 6c67965

33 files changed

+697
-352
lines changed

sdk/eventhub/azure-eventhub-checkpointstoreblob-aio/azure/eventhub/extensions/checkpointstoreblobaio/_blobstoragecsaio.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License. See License.txt in the project root for license information.
44
# --------------------------------------------------------------------------------------------
5-
from typing import Iterable, Dict, Any, Optional
5+
from typing import Iterable, Dict, Any, Optional, Union, TYPE_CHECKING
66
import logging
77
import copy
88
from collections import defaultdict
@@ -13,6 +13,9 @@
1313
from ._vendor.storage.blob.aio import ContainerClient, BlobClient
1414
from ._vendor.storage.blob._shared.base_client import parse_connection_str
1515

16+
if TYPE_CHECKING:
17+
from azure.core.credentials_async import AsyncTokenCredential
18+
from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential
1619

1720
logger = logging.getLogger(__name__)
1821
UPLOAD_DATA = ""
@@ -41,11 +44,17 @@ class BlobCheckpointStore(CheckpointStore):
4144
The hostname of the secondary endpoint.
4245
"""
4346

44-
def __init__(self, blob_account_url, container_name, *, credential=None, **kwargs):
45-
# type(str, str, Optional[Any], Any) -> None
47+
def __init__(
48+
self,
49+
blob_account_url: str,
50+
container_name: str,
51+
*,
52+
credential: Optional[Union["AsyncTokenCredential", "AzureNamedKeyCredential", "AzureSasCredential"]] = None,
53+
api_version: str = '2019-07-07',
54+
**kwargs: Any
55+
) -> None:
4656
self._container_client = kwargs.pop("container_client", None)
4757
if not self._container_client:
48-
api_version = kwargs.pop("api_version", None)
4958
if api_version:
5059
headers = kwargs.get("headers")
5160
if headers:
@@ -59,7 +68,12 @@ def __init__(self, blob_account_url, container_name, *, credential=None, **kwarg
5968

6069
@classmethod
6170
def from_connection_string(
62-
cls, conn_str: str, container_name: str, *, credential: Optional[Any] = None, **kwargs: Any
71+
cls,
72+
conn_str: str,
73+
container_name: str,
74+
*,
75+
credential: Optional[Union["AsyncTokenCredential", "AzureNamedKeyCredential", "AzureSasCredential"]] = None,
76+
**kwargs: Any
6377
) -> "BlobCheckpointStore":
6478
"""Create BlobCheckpointStore from a storage connection string.
6579
@@ -88,11 +102,11 @@ def from_connection_string(
88102

89103
return cls(account_url, container_name, credential=credential, **kwargs)
90104

91-
async def __aenter__(self):
105+
async def __aenter__(self) -> "BlobCheckpointStore":
92106
await self._container_client.__aenter__()
93107
return self
94108

95-
async def __aexit__(self, *args):
109+
async def __aexit__(self, *args: Any) -> None:
96110
await self._container_client.__aexit__(*args)
97111

98112
def _get_blob_client(self, blob_name: str) -> BlobClient:
@@ -291,8 +305,8 @@ async def update_checkpoint(self, checkpoint: Dict[str, Any], **kwargs: Any) ->
291305
)
292306

293307
async def list_checkpoints(
294-
self, fully_qualified_namespace, eventhub_name, consumer_group, **kwargs
295-
):
308+
self, fully_qualified_namespace: str, eventhub_name: str, consumer_group: str, **kwargs: Any
309+
) -> Iterable[Dict[str, Any]]:
296310
"""List the updated checkpoints from the storage blob.
297311
298312
:param str fully_qualified_namespace: The fully qualified namespace that the Event Hub belongs to.

sdk/eventhub/azure-eventhub-checkpointstoreblob/azure/eventhub/extensions/checkpointstoreblob/_blobstoragecs.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License. See License.txt in the project root for license information.
44
# --------------------------------------------------------------------------------------------
5-
from typing import Dict, Optional, Any, Iterable, Union
5+
from typing import Dict, Optional, Any, Iterable, Union, TYPE_CHECKING
66
import logging
77
import time
88
import calendar
@@ -16,6 +16,9 @@
1616
from ._vendor.storage.blob import BlobClient, ContainerClient
1717
from ._vendor.storage.blob._shared.base_client import parse_connection_str
1818

19+
if TYPE_CHECKING:
20+
from azure.core.credentials import TokenCredential, AzureSasCredential, AzureNamedKeyCredential
21+
1922
logger = logging.getLogger(__name__)
2023
UPLOAD_DATA = ""
2124

@@ -63,8 +66,14 @@ class BlobCheckpointStore(CheckpointStore):
6366
6467
"""
6568

66-
def __init__(self, blob_account_url, container_name, credential=None, **kwargs):
67-
# type(str, str, Optional[Any], Any) -> None
69+
def __init__(
70+
self,
71+
blob_account_url: str,
72+
container_name: str,
73+
credential: Optional[Union["AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None,
74+
api_version: str = '2019-07-07',
75+
**kwargs: Any
76+
) -> None:
6877
self._container_client = kwargs.pop("container_client", None)
6978
if not self._container_client:
7079
api_version = kwargs.pop("api_version", None)
@@ -81,9 +90,12 @@ def __init__(self, blob_account_url, container_name, credential=None, **kwargs):
8190

8291
@classmethod
8392
def from_connection_string(
84-
cls, conn_str, container_name, credential=None, **kwargs
85-
):
86-
# type: (str, str, Optional[Any], Any) -> BlobCheckpointStore
93+
cls,
94+
conn_str: str,
95+
container_name: str,
96+
credential: Optional[Union["AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None,
97+
**kwargs: Any
98+
) -> "BlobCheckpointStore":
8799
"""Create BlobCheckpointStore from a storage connection string.
88100
89101
:param str conn_str:
@@ -111,11 +123,11 @@ def from_connection_string(
111123

112124
return cls(account_url, container_name, credential=credential, **kwargs)
113125

114-
def __enter__(self):
126+
def __enter__(self) -> "BlobCheckpointStore":
115127
self._container_client.__enter__()
116128
return self
117129

118-
def __exit__(self, *args):
130+
def __exit__(self, *args: Any) -> None:
119131
self._container_client.__exit__(*args)
120132

121133
def _get_blob_client(self, blob_name):
@@ -183,8 +195,13 @@ def _claim_one_partition(self, ownership, **kwargs):
183195
)
184196
return updated_ownership # Keep the ownership if an unexpected error happens
185197

186-
def list_ownership(self, fully_qualified_namespace, eventhub_name, consumer_group, **kwargs):
187-
# type: (str, str, str, Any) -> Iterable[Dict[str, Any]]
198+
def list_ownership(
199+
self,
200+
fully_qualified_namespace: str,
201+
eventhub_name: str,
202+
consumer_group: str,
203+
**kwargs: Any
204+
) -> Iterable[Dict[str, Any]]:
188205
"""Retrieves a complete ownership list from the storage blob.
189206
190207
:param str fully_qualified_namespace: The fully qualified namespace that the Event Hub belongs to.
@@ -238,8 +255,11 @@ def list_ownership(self, fully_qualified_namespace, eventhub_name, consumer_grou
238255
)
239256
raise
240257

241-
def claim_ownership(self, ownership_list, **kwargs):
242-
# type: (Iterable[Dict[str, Any]], Any) -> Iterable[Dict[str, Any]]
258+
def claim_ownership(
259+
self,
260+
ownership_list: Iterable[Dict[str, Any]],
261+
**kwargs: Any
262+
) -> Iterable[Dict[str, Any]]:
243263
"""Tries to claim ownership for a list of specified partitions.
244264
245265
:param iterable[dict[str, any]] ownership_list: Iterable of dictionaries containing all the ownerships to claim.
@@ -265,8 +285,7 @@ def claim_ownership(self, ownership_list, **kwargs):
265285
pass
266286
return gathered_results
267287

268-
def update_checkpoint(self, checkpoint, **kwargs):
269-
# type: (Dict[str, Optional[Union[str, int]]], Any) -> None
288+
def update_checkpoint(self, checkpoint: Dict[str, Union[str, int]], **kwargs: Any) -> None:
270289
"""Updates the checkpoint using the given information for the offset, associated partition and
271290
consumer group in the storage blob.
272291
@@ -309,9 +328,8 @@ def update_checkpoint(self, checkpoint, **kwargs):
309328
)
310329

311330
def list_checkpoints(
312-
self, fully_qualified_namespace, eventhub_name, consumer_group, **kwargs
313-
):
314-
# type: (str, str, str, Any) -> Iterable[Dict[str, Any]]
331+
self, fully_qualified_namespace: str, eventhub_name: str, consumer_group: str, **kwargs: Any
332+
) -> Iterable[Dict[str, Any]]:
315333
"""List the updated checkpoints from the storage blob.
316334
317335
:param str fully_qualified_namespace: The fully qualified namespace that the Event Hub belongs to.
@@ -351,5 +369,5 @@ def list_checkpoints(
351369
result.append(checkpoint)
352370
return result
353371

354-
def close(self):
372+
def close(self) -> None:
355373
self._container_client.__exit__()

sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,20 @@
5050
"EventHubSharedKeyCredential",
5151
TokenCredential,
5252
]
53+
from ._consumer_client import EventHubConsumerClient
54+
from ._producer_client import EventHubProducerClient
55+
from ._transport._base import AmqpTransport
5356
try:
5457
from uamqp import Message as uamqp_Message
5558
from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth
59+
from uamqp import ReceiveClient as uamqp_AMQPRecieveClient
60+
from uamqp import SendClient as uamqp_AMQPSendClient
5661
except ImportError:
57-
uamqp_Message = None
58-
uamqp_JWTTokenAuth = None
62+
pass
5963
from ._pyamqp.message import Message
6064
from ._pyamqp.authentication import JWTTokenAuth
65+
from ._pyamqp import ReceiveClient as pyamqp_AMQPRecieveClient
66+
from ._pyamqp import SendClient as pyamqp_AMQPSendClient
6167

6268
_LOGGER = logging.getLogger(__name__)
6369
_Address = collections.namedtuple("_Address", "hostname path")
@@ -333,7 +339,7 @@ def _from_connection_string(conn_str: str, **kwargs: Any) -> Dict[str, Any]:
333339
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
334340
return kwargs
335341

336-
def _create_auth(self) -> Union[uamqp_JWTTokenAuth, JWTTokenAuth]:
342+
def _create_auth(self) -> Union["uamqp_JWTTokenAuth", JWTTokenAuth]:
337343
"""
338344
Create an ~uamqp.authentication.SASTokenAuth instance
339345
to authenticate the session.
@@ -414,7 +420,8 @@ def _management_request(
414420
mgmt_client.open(connection=conn)
415421
while not mgmt_client.client_ready():
416422
time.sleep(0.05)
417-
mgmt_msg.application_properties[
423+
424+
cast(Dict[Union[str, bytes], Any], mgmt_msg.application_properties)[
418425
"security_token"
419426
] = self._amqp_transport.get_updated_token(mgmt_auth)
420427
status_code, description, response = self._amqp_transport.mgmt_client_request(
@@ -510,11 +517,22 @@ def _close(self) -> None:
510517
self._conn_manager.close_connection()
511518

512519

513-
class ConsumerProducerMixin(object):
514-
def __enter__(self):
520+
class ConsumerProducerMixin():
521+
522+
def __init__(self) -> None:
523+
self._handler: Union[
524+
uamqp_AMQPRecieveClient,
525+
pyamqp_AMQPRecieveClient,
526+
uamqp_AMQPSendClient,
527+
pyamqp_AMQPSendClient]
528+
self._client: Union[EventHubConsumerClient, EventHubProducerClient]
529+
self._amqp_transport: "AmqpTransport"
530+
self._max_message_size_on_link: Optional[int] = None
531+
532+
def __enter__(self) -> ConsumerProducerMixin:
515533
return self
516534

517-
def __exit__(self, exc_type, exc_val, exc_tb):
535+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
518536
self.close()
519537

520538
def _create_handler(self, auth):
@@ -526,8 +544,12 @@ def _check_closed(self):
526544
f"{self._name} has been closed. Please create a new one to handle event data."
527545
)
528546

529-
def _open(self):
530-
"""Open the EventHubConsumer/EventHubProducer using the supplied connection."""
547+
def _open(self) -> bool:
548+
"""Open the EventHubConsumer/EventHubProducer using the supplied connection.
549+
550+
:return: Whether the EventHubConsumer/EventHubProducer is ready to use.
551+
:rtype: bool
552+
"""
531553
# pylint: disable=protected-access
532554
if not self.running:
533555
if self._handler:
@@ -544,9 +566,11 @@ def _open(self):
544566
self._amqp_transport.get_remote_max_message_size(self._handler)
545567
or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES
546568
)
547-
self.running = True
569+
self.running: bool = True
570+
return True
571+
return False
548572

549-
def _close_handler(self):
573+
def _close_handler(self) -> None:
550574
if self._handler:
551575
self._handler.close() # close the link (sharing connection) or connection (not sharing)
552576
self.running = False

sdk/eventhub/azure-eventhub/azure/eventhub/_common.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@
6464
BatchMessage,
6565
)
6666
except ImportError:
67-
Message = None
68-
BatchMessage = None
67+
pass
68+
6969
from ._transport._base import AmqpTransport
7070

7171
MessageContent = TypedDict("MessageContent", {"content": bytes, "content_type": str})
@@ -76,8 +76,8 @@
7676
bytes,
7777
bool,
7878
str,
79-
Dict,
80-
List,
79+
Dict[str, Any],
80+
List[Any],
8181
uuid.UUID,
8282
]
8383
]
@@ -348,12 +348,12 @@ def properties(self) -> Dict[Union[str, bytes], Any]:
348348
return self._raw_amqp_message.application_properties
349349

350350
@properties.setter
351-
def properties(self, value: Dict[Union[str, bytes], Any]):
351+
def properties(self, value: Dict[Union[str, bytes], Any]) -> None:
352352
"""Application-defined properties on the event.
353353
354354
:param dict[str, any] or dict[bytes, any] value: The application properties for the EventData.
355355
"""
356-
properties = None if value is None else dict(value)
356+
properties = None if value is None else value
357357
self._raw_amqp_message.application_properties = properties
358358

359359
@property
@@ -550,7 +550,7 @@ def __init__(
550550
max_size_in_bytes: Optional[int] = None,
551551
partition_id: Optional[str] = None,
552552
partition_key: Optional[Union[str, bytes]] = None,
553-
**kwargs,
553+
**kwargs: Any,
554554
) -> None:
555555
self._amqp_transport = kwargs.pop("amqp_transport", PyamqpTransport)
556556
self._tracing_attributes: Dict[str, Union[str, int]] = kwargs.pop("tracing_attributes", {})
@@ -571,7 +571,7 @@ def __init__(
571571
self._message, self._partition_key
572572
)
573573
self._size = self._amqp_transport.get_batch_message_encoded_size(self._message)
574-
self.max_size_in_bytes = (
574+
self.max_size_in_bytes: int = (
575575
max_size_in_bytes or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES
576576
)
577577

sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class _ConnectionMode(Enum):
4949

5050

5151
class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes
52-
def __init__(self, **kwargs):
52+
def __init__(self, **kwargs: Any):
5353
self._lock = Lock()
5454
self._conn: Union[Connection, uamqp_Connection] = None
5555

0 commit comments

Comments
 (0)