Skip to content

Commit 5ea1750

Browse files
l0lawrencekashifkhanswathipil
authored
[EH Emulator] WIP support for emulator (Azure#34103)
* add use_tls bool * adding on * async consumer * add for localhsot * add to producer * prod client * getting closer * reocgnize emulator slug if true * udate * sync websocket * remove print * async * update * d't use custom endpont * pass through * updates * comment Co-authored-by: Kashif Khan <[email protected]> * updates * comment * pass thrpugh * update * error * async * changes * fix * sync send/recieve * sync * nits * nit * add async * remove * test * test * update typing * pylint * update if * nit * add ipv6 test * set true at top level * remove true * use config * my changes * move everything to pyamqp * remove * nit * space * remove print * remove print + typing * merge error * nits * Update sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py Co-authored-by: Kashif Khan <[email protected]> * websocket port * apply to sync * websocket port * Update sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py Co-authored-by: Kashif Khan <[email protected]> --------- Co-authored-by: Kashif Khan <[email protected]> Co-authored-by: swathipil <[email protected]>
1 parent 7a47adf commit 5ea1750

File tree

13 files changed

+98
-15
lines changed

13 files changed

+98
-15
lines changed

sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010
import functools
1111
import collections
12+
import re
1213
from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union
1314
from datetime import timedelta
1415
from urllib.parse import urlparse
@@ -68,20 +69,23 @@
6869
_LOGGER = logging.getLogger(__name__)
6970
_Address = collections.namedtuple("_Address", "hostname path")
7071

72+
def _is_local_endpoint(endpoint: str) -> bool:
73+
return re.match("^(127\\.[\\d.]+|[0:]+1|localhost)", endpoint.lower()) is not None
7174

7275
def _parse_conn_str(
7376
conn_str: str, # pylint:disable=unused-argument
7477
*,
7578
eventhub_name: Optional[str] = None,
7679
check_case: bool = False,
7780
**kwargs: Any
78-
) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]:
81+
) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int], bool]:
7982
endpoint = None
8083
shared_access_key_name = None
8184
shared_access_key = None
8285
entity_path: Optional[str] = None
8386
shared_access_signature: Optional[str] = None
8487
shared_access_signature_expiry = None
88+
use_emulator: Optional[str] = None
8589
conn_settings = core_parse_connection_string(
8690
conn_str, case_sensitive_keys=check_case
8791
)
@@ -95,6 +99,7 @@ def _parse_conn_str(
9599
# only sas check is non case sensitive for both conn str properties and internal use
96100
if key.lower() == "sharedaccesssignature":
97101
shared_access_signature = value
102+
use_emulator = conn_settings.get("UseDevelopmentEmulator")
98103

99104
if not check_case:
100105
endpoint = conn_settings.get("endpoint") or conn_settings.get("hostname")
@@ -104,6 +109,7 @@ def _parse_conn_str(
104109
shared_access_key = conn_settings.get("sharedaccesskey")
105110
entity_path = conn_settings.get("entitypath")
106111
shared_access_signature = conn_settings.get("sharedaccesssignature")
112+
use_emulator = conn_settings.get("usedevelopmentemulator")
107113

108114
if shared_access_signature:
109115
try:
@@ -124,11 +130,21 @@ def _parse_conn_str(
124130
# check that endpoint is valid
125131
if not endpoint:
126132
raise ValueError("Connection string is either blank or malformed.")
133+
127134
parsed = urlparse(endpoint)
128135
if not parsed.netloc:
129136
raise ValueError("Invalid Endpoint on the Connection String.")
130137
host = cast(str, parsed.netloc.strip())
131138

139+
emulator = use_emulator=="true"
140+
if emulator and not _is_local_endpoint(host):
141+
raise ValueError(
142+
"Invalid endpoint on the connection string. "
143+
"For development connection strings, should be in the format: "
144+
"Endpoint=sb://localhost;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>;"
145+
"UseDevelopmentEmulator=true;"
146+
)
147+
132148
if any([shared_access_key, shared_access_key_name]) and not all(
133149
[shared_access_key, shared_access_key_name]
134150
):
@@ -154,6 +170,7 @@ def _parse_conn_str(
154170
entity,
155171
str(shared_access_signature) if shared_access_signature else None,
156172
shared_access_signature_expiry,
173+
emulator
157174
)
158175

159176

@@ -327,11 +344,15 @@ def __init__(
327344

328345
@staticmethod
329346
def _from_connection_string(conn_str: str, **kwargs: Any) -> Dict[str, Any]:
330-
host, policy, key, entity, token, token_expiry = _parse_conn_str(
347+
host, policy, key, entity, token, token_expiry, emulator = _parse_conn_str(
331348
conn_str, **kwargs
332349
)
350+
333351
kwargs["fully_qualified_namespace"] = host
334352
kwargs["eventhub_name"] = entity
353+
# Check if emulator is in use, unset tls if it is
354+
if emulator:
355+
kwargs["use_tls"] = False
335356
if token and token_expiry:
336357
kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry)
337358
elif policy and key:

sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
send_timeout: int = 60,
3636
custom_endpoint_address: Optional[str] = None,
3737
connection_verify: Optional[str] = None,
38+
use_tls: bool = True,
3839
**kwargs: Any
3940
):
4041
self.user_agent = user_agent
@@ -59,6 +60,7 @@ def __init__(
5960
self.connection_verify = connection_verify
6061
self.custom_endpoint_hostname = None
6162
self.hostname = hostname
63+
self.use_tls = use_tls
6264

6365
if self.http_proxy or self.transport_type.value == TransportType.AmqpOverWebsocket.value:
6466
self.transport_type = TransportType.AmqpOverWebsocket

sdk/eventhub/azure-eventhub/azure/eventhub/_connection_string_parser.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def parse_connection_string(conn_str: str) -> "EventHubConnectionStringPropertie
9191
:return: A properties bag containing the parsed connection string.
9292
:rtype: ~azure.eventhub.EventHubConnectionStringProperties
9393
"""
94-
fully_qualified_namespace, policy, key, entity, signature = _parse_conn_str(
94+
fully_qualified_namespace, policy, key, entity, signature, emulator = _parse_conn_str(
9595
conn_str, check_case=True
9696
)[:-1]
9797
endpoint = "sb://" + fully_qualified_namespace + "/"
@@ -102,5 +102,6 @@ def parse_connection_string(conn_str: str) -> "EventHubConnectionStringPropertie
102102
"shared_access_signature": signature,
103103
"shared_access_key_name": policy,
104104
"shared_access_key": key,
105+
"emulator": emulator,
105106
}
106107
return EventHubConnectionStringProperties(**props)

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ def __init__( # pylint:disable=too-many-locals,too-many-statements
126126
self._port = PORT
127127
self.state: Optional[ConnectionState] = None
128128

129+
# Set the port for AmqpOverWebsocket
130+
if transport_type.value == TransportType.AmqpOverWebsocket.value:
131+
self._port = WEBSOCKET_PORT
132+
129133
# Custom Endpoint
130134
custom_endpoint_address = kwargs.get("custom_endpoint_address")
131135
custom_endpoint = None
@@ -157,6 +161,7 @@ def __init__( # pylint:disable=too-many-locals,too-many-statements
157161
self._transport = sasl_transport(
158162
host=endpoint,
159163
credential=kwargs["sasl_credential"],
164+
port=self._port,
160165
custom_endpoint=custom_endpoint,
161166
socket_timeout=self._socket_timeout,
162167
network_trace_params=self._network_trace_params,

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def __init__(
171171
socket_timeout=SOCKET_TIMEOUT,
172172
socket_settings=None,
173173
raise_on_initial_eintr=True,
174+
use_tls: bool =True,
174175
**kwargs
175176
):
176177
self._quick_recv = None
@@ -186,6 +187,8 @@ def __init__(
186187
self.socket_settings = socket_settings
187188
self.socket_lock = Lock()
188189

190+
self._use_tls = use_tls
191+
189192
def connect(self):
190193
try:
191194
# are we already connected?
@@ -509,7 +512,8 @@ def __init__(
509512

510513
def _setup_transport(self):
511514
"""Wrap the socket in an SSL object."""
512-
self.sock = self._wrap_socket(self.sock, **self.sslopts)
515+
if self._use_tls:
516+
self.sock = self._wrap_socket(self.sock, **self.sslopts)
513517
self._quick_recv = self.sock.recv
514518

515519
def _wrap_socket(self, sock, context=None, **sslopts):
@@ -599,7 +603,8 @@ def _shutdown_transport(self):
599603
"""Unwrap a SSL socket, so we can call shutdown()."""
600604
if self.sock is not None:
601605
try:
602-
self.sock = self.sock.unwrap()
606+
if self._use_tls:
607+
self.sock = self.sock.unwrap()
603608
except OSError:
604609
pass
605610

@@ -737,11 +742,12 @@ def connect(self):
737742
) from None
738743
try:
739744
self.sock = create_connection(
740-
url="wss://{}".format(self._custom_endpoint or self._host),
745+
url="wss://{}".format(self._custom_endpoint or self._host) if self._use_tls
746+
else "ws://{}".format(self._custom_endpoint or self._host),
741747
subprotocols=[AMQP_WS_SUBPROTOCOL],
742748
timeout=self.socket_timeout, # timeout for read/write operations
743749
skip_utf8_validation=True,
744-
sslopt=self.sslopts,
750+
sslopt=self.sslopts if self._use_tls else None,
745751
http_proxy_host=http_proxy_host,
746752
http_proxy_port=http_proxy_port,
747753
http_proxy_auth=http_proxy_auth,

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ async def open_async(self, connection=None):
244244
self._external_connection = True
245245
if not self._connection:
246246
self._connection = Connection(
247-
"amqps://" + self._hostname,
247+
"amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname,
248248
sasl_credential=self._auth.sasl,
249249
ssl_opts={'ca_certs': self._connection_verify or certifi.where()},
250250
container_id=self._name,
@@ -257,6 +257,7 @@ async def open_async(self, connection=None):
257257
http_proxy=self._http_proxy,
258258
custom_endpoint_address=self._custom_endpoint_address,
259259
socket_timeout=self._socket_timeout,
260+
use_tls=self._use_tls,
260261
)
261262
await self._connection.open()
262263
if not self._session:

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def __init__(# pylint:disable=too-many-locals,too-many-statements
107107
self._port = PORT
108108
self.state: Optional[ConnectionState] = None
109109

110+
# Set the port for AmqpOverWebsocket
111+
if transport_type.value == TransportType.AmqpOverWebsocket.value:
112+
self._port = WEBSOCKET_PORT
113+
110114
# Custom Endpoint
111115
custom_endpoint_address = kwargs.get("custom_endpoint_address")
112116
custom_endpoint = None
@@ -141,6 +145,7 @@ def __init__(# pylint:disable=too-many-locals,too-many-statements
141145
self._transport: Union[SASLTransport, SASLWithWebSocket, AsyncTransport] = sasl_transport(
142146
host=endpoint,
143147
credential=kwargs["sasl_credential"],
148+
port=self._port,
144149
custom_endpoint=custom_endpoint,
145150
socket_timeout=self._socket_timeout,
146151
network_trace_params=self._network_trace_params,

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def __init__(
247247
ssl_opts=False,
248248
socket_settings=None,
249249
raise_on_initial_eintr=True,
250+
use_tls: bool = True,
250251
**kwargs, # pylint: disable=unused-argument
251252
):
252253
self.connected = False
@@ -260,6 +261,7 @@ def __init__(
260261
self.socket_lock = asyncio.Lock()
261262
self.sslopts = ssl_opts
262263
self.network_trace_params = kwargs.get('network_trace_params')
264+
self._use_tls = use_tls
263265

264266
async def connect(self):
265267
try:
@@ -280,10 +282,10 @@ async def connect(self):
280282
self.reader, self.writer = await asyncio.open_connection(
281283
host=self.host,
282284
port=self.port,
283-
ssl=self.sslopts,
285+
ssl=self.sslopts if self._use_tls else None,
284286
family=socket.AF_UNSPEC,
285287
proto=SOL_TCP,
286-
server_hostname=self.host if self.sslopts else None,
288+
server_hostname=self.host if self._use_tls else None,
287289
)
288290
self.connected = True
289291
sock = self.writer.transport.get_extra_info("socket")
@@ -436,6 +438,7 @@ def __init__(
436438
port=WEBSOCKET_PORT,
437439
socket_timeout=CONNECT_TIMEOUT,
438440
ssl_opts=None,
441+
use_tls: bool =True,
439442
**kwargs
440443
):
441444
self._read_buffer = BytesIO()
@@ -449,6 +452,7 @@ def __init__(
449452
self._http_proxy = kwargs.get("http_proxy", None)
450453
self.connected = False
451454
self.network_trace_params = kwargs.get('network_trace_params')
455+
self._use_tls = use_tls
452456

453457
async def connect(self):
454458
self.sslopts = self._build_ssl_opts(self.sslopts)
@@ -478,9 +482,9 @@ async def connect(self):
478482

479483
self.session = ClientSession()
480484
if self._custom_endpoint:
481-
url = f"wss://{self._custom_endpoint}"
485+
url = f"wss://{self._custom_endpoint}" if self._use_tls else f"ws://{self._custom_endpoint}"
482486
else:
483-
url = f"wss://{self.host}"
487+
url = f"wss://{self.host}" if self._use_tls else f"ws://{self.host}"
484488
parsed_url = urlsplit(url)
485489
url = f"{parsed_url.scheme}://{parsed_url.netloc}:{self.port}{parsed_url.path}"
486490

@@ -500,7 +504,7 @@ async def connect(self):
500504
autoclose=False,
501505
proxy=http_proxy_host,
502506
proxy_auth=http_proxy_auth,
503-
ssl=self.sslopts,
507+
ssl=self.sslopts if self._use_tls else None,
504508
heartbeat=DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS,
505509
)
506510
except ClientConnectorError as exc:

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ def __init__(self, hostname, **kwargs):
214214
self._custom_endpoint_address = kwargs.get("custom_endpoint_address")
215215
self._connection_verify = kwargs.get("connection_verify")
216216

217+
# Emulator
218+
self._use_tls: bool = kwargs.get("use_tls", True)
219+
217220
def __enter__(self):
218221
"""Run Client in a context manager.
219222
@@ -312,7 +315,7 @@ def open(self, connection=None):
312315
self._external_connection = True
313316
elif not self._connection:
314317
self._connection = Connection(
315-
"amqps://" + self._hostname,
318+
"amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname,
316319
sasl_credential=self._auth.sasl,
317320
ssl_opts={"ca_certs": self._connection_verify or certifi.where()},
318321
container_id=self._name,
@@ -325,6 +328,7 @@ def open(self, connection=None):
325328
http_proxy=self._http_proxy,
326329
custom_endpoint_address=self._custom_endpoint_address,
327330
socket_timeout=self._socket_timeout,
331+
use_tls=self._use_tls,
328332
)
329333
self._connection.open()
330334
if not self._session:

sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ def create_send_client(# pylint: disable=unused-argument
351351
link_properties=link_properties,
352352
properties=properties,
353353
client_name=client_name,
354+
use_tls=config.use_tls,
354355
**kwargs,
355356
)
356357

@@ -507,6 +508,7 @@ def create_receive_client(
507508
keep_alive_interval=keep_alive_interval,
508509
streaming_receive=streaming_receive,
509510
timeout=timeout,
511+
use_tls=config.use_tls,
510512
**kwargs,
511513
)
512514

@@ -604,6 +606,7 @@ def create_mgmt_client(
604606
http_proxy=config.http_proxy,
605607
custom_endpoint_address=config.custom_endpoint_address,
606608
connection_verify=config.connection_verify,
609+
use_tls=config.use_tls,
607610
)
608611

609612
@staticmethod

0 commit comments

Comments
 (0)