2222import time
2323from datetime import datetime
2424from typing import Any , AsyncGenerator , Dict , Generator , Tuple , cast
25- from unittest import SkipTest , TestCase
25+ from unittest import SkipTest
2626from unittest .mock import AsyncMock , Mock
2727
2828import pytest_asyncio
4646 create_flat_git_index ,
4747 create_git_index ,
4848)
49+ from ..utils import CA_CERTS
4950
50- if "ELASTICSEARCH_URL" in os .environ :
51- ELASTICSEARCH_URL = os .environ ["ELASTICSEARCH_URL" ]
52- else :
53- ELASTICSEARCH_URL = "http://localhost:9200"
5451
55-
56- def get_test_client (wait : bool = True , ** kwargs : Any ) -> Elasticsearch :
52+ def get_test_client (elasticsearch_url , wait : bool = True , ** kwargs : Any ) -> Elasticsearch :
5753 # construct kwargs from the environment
5854 kw : Dict [str , Any ] = {"request_timeout" : 30 }
5955
56+ if elasticsearch_url .startswith ("https://" ):
57+ kw ["ca_certs" ] = CA_CERTS
58+
6059 if "PYTHON_CONNECTION_CLASS" in os .environ :
6160 kw ["node_class" ] = os .environ ["PYTHON_CONNECTION_CLASS" ]
6261
6362 kw .update (kwargs )
64- client = Elasticsearch (ELASTICSEARCH_URL , ** kw )
63+ client = Elasticsearch (elasticsearch_url , ** kw )
6564
6665 # wait for yellow status
6766 for tries_left in range (100 if wait else 1 , 0 , - 1 ):
@@ -76,15 +75,17 @@ def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch:
7675 raise SkipTest ("Elasticsearch failed to start." )
7776
7877
79- async def get_async_test_client (wait : bool = True , ** kwargs : Any ) -> AsyncElasticsearch :
78+ async def get_async_test_client (
79+ elasticsearch_url , wait : bool = True , ** kwargs : Any
80+ ) -> AsyncElasticsearch :
8081 # construct kwargs from the environment
8182 kw : Dict [str , Any ] = {"request_timeout" : 30 }
8283
83- if "PYTHON_CONNECTION_CLASS" in os . environ :
84- kw ["node_class " ] = os . environ [ "PYTHON_CONNECTION_CLASS" ]
84+ if elasticsearch_url . startswith ( "https://" ) :
85+ kw ["ca_certs " ] = CA_CERTS
8586
8687 kw .update (kwargs )
87- client = AsyncElasticsearch (ELASTICSEARCH_URL , ** kw )
88+ client = AsyncElasticsearch (elasticsearch_url , ** kw )
8889
8990 # wait for yellow status
9091 for tries_left in range (100 if wait else 1 , 0 , - 1 ):
@@ -100,36 +101,6 @@ async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasti
100101 raise SkipTest ("Elasticsearch failed to start." )
101102
102103
103- class ElasticsearchTestCase (TestCase ):
104- client : Elasticsearch
105-
106- @staticmethod
107- def _get_client () -> Elasticsearch :
108- return get_test_client ()
109-
110- @classmethod
111- def setup_class (cls ) -> None :
112- cls .client = cls ._get_client ()
113-
114- def teardown_method (self , _ : Any ) -> None :
115- # Hidden indices expanded in wildcards in ES 7.7
116- expand_wildcards = ["open" , "closed" ]
117- if self .es_version () >= (7 , 7 ):
118- expand_wildcards .append ("hidden" )
119-
120- self .client .indices .delete_data_stream (
121- name = "*" , expand_wildcards = expand_wildcards
122- )
123- self .client .indices .delete (index = "*" , expand_wildcards = expand_wildcards )
124- self .client .indices .delete_template (name = "*" )
125- self .client .indices .delete_index_template (name = "*" )
126-
127- def es_version (self ) -> Tuple [int , ...]:
128- if not hasattr (self , "_es_version" ):
129- self ._es_version = _get_version (self .client .info ()["version" ]["number" ])
130- return self ._es_version
131-
132-
133104def _get_version (version_string : str ) -> Tuple [int , ...]:
134105 if "." not in version_string :
135106 return ()
@@ -138,19 +109,23 @@ def _get_version(version_string: str) -> Tuple[int, ...]:
138109
139110
140111@fixture (scope = "session" )
141- def client () -> Elasticsearch :
112+ def client (elasticsearch_url ) -> Elasticsearch :
142113 try :
143- connection = get_test_client (wait = "WAIT_FOR_ES" in os .environ )
114+ connection = get_test_client (
115+ elasticsearch_url , wait = "WAIT_FOR_ES" in os .environ
116+ )
144117 add_connection ("default" , connection )
145118 return connection
146119 except SkipTest :
147120 skip ()
148121
149122
150123@pytest_asyncio .fixture
151- async def async_client () -> AsyncGenerator [AsyncElasticsearch , None ]:
124+ async def async_client (elasticsearch_url ) -> AsyncGenerator [AsyncElasticsearch , None ]:
152125 try :
153- connection = await get_async_test_client (wait = "WAIT_FOR_ES" in os .environ )
126+ connection = await get_async_test_client (
127+ elasticsearch_url , wait = "WAIT_FOR_ES" in os .environ
128+ )
154129 add_async_connection ("default" , connection )
155130 yield connection
156131 await connection .close ()
0 commit comments