1010import uuid
1111import warnings
1212import asyncio
13-
13+ from functools import partial
14+ from devtools_testutils import get_credential
1415from azure .eventhub .extensions .checkpointstoreblobaio import BlobCheckpointStore
15- from azure .eventhub .extensions .checkpointstoreblobaio ._vendor .storage .blob import BlobServiceClient
16+ from azure .eventhub .extensions .checkpointstoreblobaio ._vendor .storage .blob . aio import BlobServiceClient
1617
1718STORAGE_ENV_KEYS = [
18- "AZURE_STORAGE_CONN_STR" ,
19- "AZURE_STORAGE_DATA_LAKE_ENABLED_CONN_STR"
19+ "AZURE_STORAGE_ACCOUNT" ,
2020]
2121
2222
23- def get_live_storage_blob_client (conn_str_env_key ):
24- try :
25- storage_connection_str = os .environ [conn_str_env_key ]
26- container_name = str (uuid .uuid4 ())
27- blob_service_client = BlobServiceClient .from_connection_string (storage_connection_str )
28- blob_service_client .create_container (container_name )
29- return storage_connection_str , container_name
30- except :
31- pytest .skip ("Storage blob client can't be created" )
32-
33-
34- def remove_live_storage_blob_client (storage_connection_str , container_str ):
35- try :
36- blob_service_client = BlobServiceClient .from_connection_string (storage_connection_str )
37- blob_service_client .delete_container (container_str )
38- except :
39- warnings .warn (UserWarning ("storage container teardown failed" ))
23+ async def get_live_storage_blob_client ( storage_account ):
24+ storage_account = "https://{}.blob.core.windows.net" .format (
25+ os .environ [storage_account ])
26+ container_name = str (uuid .uuid4 ())
27+ blob_service_client = BlobServiceClient (storage_account , get_credential (is_async = True ))
28+ await blob_service_client .create_container (container_name )
29+ return storage_account , container_name
4030
4131
42- async def _claim_and_list_ownership (connection_str , container_name ):
32+ async def _claim_and_list_ownership ( storage_account , container_name ):
4333 fully_qualified_namespace = 'test_namespace'
4434 eventhub_name = 'eventhub'
4535 consumer_group = '$default'
4636 ownership_cnt = 8
4737
48- checkpoint_store = BlobCheckpointStore .from_connection_string (connection_str , container_name )
38+ credential = get_credential (is_async = True )
39+
40+ checkpoint_store = BlobCheckpointStore (storage_account , container_name , credential = credential )
4941 async with checkpoint_store :
5042 ownership_list = await checkpoint_store .list_ownership (
5143 fully_qualified_namespace = fully_qualified_namespace ,
@@ -78,13 +70,15 @@ async def _claim_and_list_ownership(connection_str, container_name):
7870 assert len (ownership_list ) == ownership_cnt
7971
8072
81- async def _update_checkpoint (connection_str , container_name ):
73+ async def _update_checkpoint ( storage_account , container_name ):
8274 fully_qualified_namespace = 'test_namespace'
8375 eventhub_name = 'eventhub'
8476 consumer_group = '$default'
8577 partition_cnt = 8
8678
87- checkpoint_store = BlobCheckpointStore .from_connection_string (connection_str , container_name )
79+ credential = get_credential (is_async = True )
80+
81+ checkpoint_store = BlobCheckpointStore (storage_account , container_name , credential = credential )
8882 async with checkpoint_store :
8983 for i in range (partition_cnt ):
9084 checkpoint = {
@@ -107,23 +101,21 @@ async def _update_checkpoint(connection_str, container_name):
107101 assert checkpoint ['sequence_number' ] == 20
108102
109103
110- @pytest .mark .parametrize ("conn_str_env_key" , STORAGE_ENV_KEYS )
111- @pytest .mark .liveTest
112- def test_claim_and_list_ownership (conn_str_env_key ):
113- storage_connection_str , container_name = get_live_storage_blob_client (conn_str_env_key )
114- try :
115- loop = asyncio .get_event_loop ()
116- loop .run_until_complete (_claim_and_list_ownership (storage_connection_str , container_name ))
117- finally :
118- remove_live_storage_blob_client (storage_connection_str , container_name )
119-
120-
121- @pytest .mark .parametrize ("conn_str_env_key" , STORAGE_ENV_KEYS )
122- @pytest .mark .liveTest
123- def test_update_checkpoint (conn_str_env_key ):
124- storage_connection_str , container_name = get_live_storage_blob_client (conn_str_env_key )
125- try :
126- loop = asyncio .get_event_loop ()
127- loop .run_until_complete (_update_checkpoint (storage_connection_str , container_name ))
128- finally :
129- remove_live_storage_blob_client (storage_connection_str , container_name )
104+ @pytest .mark .parametrize ("storage_account" , STORAGE_ENV_KEYS )
105+ @pytest .mark .live_test_only
106+ @pytest .mark .asyncio
107+ async def test_claim_and_list_ownership_async ( storage_account ):
108+ storage_account , container_name = await get_live_storage_blob_client (storage_account )
109+ await _claim_and_list_ownership (storage_account , container_name )
110+ blob_service_client = BlobServiceClient (storage_account , credential = get_credential (is_async = True ))
111+ blob_service_client .delete_container (container_name )
112+
113+
114+ @pytest .mark .parametrize ("storage_account" , STORAGE_ENV_KEYS )
115+ @pytest .mark .live_test_only
116+ @pytest .mark .asyncio
117+ async def test_update_checkpoint_async ( storage_account ):
118+ storage_account , container_name = await get_live_storage_blob_client (storage_account )
119+ await _update_checkpoint (storage_account , container_name )
120+ blob_service_client = BlobServiceClient (storage_account , credential = get_credential (is_async = True ))
121+ blob_service_client .delete_container (container_name )
0 commit comments