2222import time
2323from datetime import datetime
2424from typing import Any , AsyncGenerator , Dict , Generator , Tuple , cast
25- from unittest import SkipTest , TestCase
25+ from unittest import SkipTest
2626from unittest .mock import AsyncMock , Mock
2727
2828import pytest_asyncio
5353 ELASTICSEARCH_URL = "http://localhost:9200"
5454
5555
56- def get_test_client (wait : bool = True , ** kwargs : Any ) -> Elasticsearch :
56+ def get_test_client (url , wait : bool = True , ** kwargs : Any ) -> Elasticsearch :
5757 # construct kwargs from the environment
5858 kw : Dict [str , Any ] = {"request_timeout" : 30 }
5959
6060 if "PYTHON_CONNECTION_CLASS" in os .environ :
6161 kw ["node_class" ] = os .environ ["PYTHON_CONNECTION_CLASS" ]
6262
6363 kw .update (kwargs )
64- client = Elasticsearch (ELASTICSEARCH_URL , ** kw )
64+ client = Elasticsearch (url , ** kw )
65+ return client
6566
6667 # wait for yellow status
67- for tries_left in range (100 if wait else 1 , 0 , - 1 ):
68- try :
69- client .cluster .health (wait_for_status = "yellow" )
70- return client
71- except ConnectionError :
72- if wait and tries_left == 1 :
73- raise
74- time .sleep (0.1 )
75-
76- raise SkipTest ("Elasticsearch failed to start." )
77-
78-
79- async def get_async_test_client (wait : bool = True , ** kwargs : Any ) -> AsyncElasticsearch :
68+ # for tries_left in range(100 if wait else 1, 0, -1):
69+ # try:
70+ # client.cluster.health(wait_for_status="yellow")
71+ # return client
72+ # except ConnectionError:
73+ # if wait and tries_left == 1:
74+ # raise
75+ # time.sleep(0.1)
76+ #
77+ # raise SkipTest("Elasticsearch failed to start.")
78+
79+
80+ async def get_async_test_client (
81+ url , wait : bool = True , ** kwargs : Any
82+ ) -> AsyncElasticsearch :
8083 # construct kwargs from the environment
8184 kw : Dict [str , Any ] = {"request_timeout" : 30 }
82-
83- if "PYTHON_CONNECTION_CLASS" in os .environ :
84- kw ["node_class" ] = os .environ ["PYTHON_CONNECTION_CLASS" ]
85-
8685 kw .update (kwargs )
87- client = AsyncElasticsearch (ELASTICSEARCH_URL , ** kw )
88-
89- # wait for yellow status
90- for tries_left in range (100 if wait else 1 , 0 , - 1 ):
91- try :
92- await client .cluster .health (wait_for_status = "yellow" )
93- return client
94- except ConnectionError :
95- if wait and tries_left == 1 :
96- raise
97- await asyncio .sleep (0.1 )
98-
86+ client = AsyncElasticsearch (url , ** kw )
87+ yield client
9988 await client .close ()
100- raise SkipTest ("Elasticsearch failed to start." )
10189
102-
103- class ElasticsearchTestCase (TestCase ):
104- client : Elasticsearch
105-
106- @staticmethod
107- def _get_client () -> Elasticsearch :
108- return get_test_client ()
109-
110- @classmethod
111- def setup_class (cls ) -> None :
112- cls .client = cls ._get_client ()
113-
114- def teardown_method (self , _ : Any ) -> None :
115- # Hidden indices expanded in wildcards in ES 7.7
116- expand_wildcards = ["open" , "closed" ]
117- if self .es_version () >= (7 , 7 ):
118- expand_wildcards .append ("hidden" )
119-
120- self .client .indices .delete_data_stream (
121- name = "*" , expand_wildcards = expand_wildcards
122- )
123- self .client .indices .delete (index = "*" , expand_wildcards = expand_wildcards )
124- self .client .indices .delete_template (name = "*" )
125- self .client .indices .delete_index_template (name = "*" )
126-
127- def es_version (self ) -> Tuple [int , ...]:
128- if not hasattr (self , "_es_version" ):
129- self ._es_version = _get_version (self .client .info ()["version" ]["number" ])
130- return self ._es_version
90+ # wait for yellow status
91+ # for tries_left in range(100 if wait else 1, 0, -1):
92+ # try:
93+ # await client.cluster.health(wait_for_status="yellow")
94+ # return client
95+ # except ConnectionError:
96+ # if wait and tries_left == 1:
97+ # raise
98+ # await asyncio.sleep(0.1)
99+ #
100+ # await client.close()
101+ # raise SkipTest("Elasticsearch failed to start.")
131102
132103
133104def _get_version (version_string : str ) -> Tuple [int , ...]:
@@ -138,19 +109,23 @@ def _get_version(version_string: str) -> Tuple[int, ...]:
138109
139110
140111@fixture (scope = "session" )
141- def client () -> Elasticsearch :
112+ def client (elasticsearch_url ) -> Elasticsearch :
142113 try :
143- connection = get_test_client (wait = "WAIT_FOR_ES" in os .environ )
114+ connection = get_test_client (
115+ elasticsearch_url , wait = "WAIT_FOR_ES" in os .environ
116+ )
144117 add_connection ("default" , connection )
145118 return connection
146119 except SkipTest :
147120 skip ()
148121
149122
150123@pytest_asyncio .fixture
151- async def async_client () -> AsyncGenerator [AsyncElasticsearch , None ]:
124+ async def async_client (elasticsearch_url ) -> AsyncGenerator [AsyncElasticsearch , None ]:
152125 try :
153- connection = await get_async_test_client (wait = "WAIT_FOR_ES" in os .environ )
126+ connection = await get_async_test_client (
127+ elasticsearch_url , wait = "WAIT_FOR_ES" in os .environ
128+ )
154129 add_async_connection ("default" , connection )
155130 yield connection
156131 await connection .close ()
0 commit comments