21
21
import asyncio
22
22
import time
23
23
import unittest
24
- from pathlib import Path
25
- from typing import Optional , Sequence
24
+ from typing import Sequence
26
25
27
26
from derived .thrift_clients import DerivedTestingService
28
27
from derived .thrift_services import DerivedTestingServiceInterface
35
34
from testing .thrift_services import TestingServiceInterface
36
35
from testing .thrift_types import Color , easy , SimpleError
37
36
from thrift .lib .python .test .event_handlers .helper import ThrowHelper , ThrowHelperHandler
37
+ from thrift .lib .python .test .test_server import TestServer
38
38
from thrift .py3 .server import get_context , SocketAddress
39
39
from thrift .python .client import get_client
40
40
from thrift .python .common import Priority , RpcOptions
41
41
from thrift .python .exceptions import ApplicationError
42
- from thrift .python .server import ServiceInterface , ThriftServer
42
+ from thrift .python .server import ServiceInterface
43
43
44
44
45
45
class Handler (TestingServiceInterface ):
@@ -105,31 +105,15 @@ async def derived_pick_a_color(self, color: Color) -> Color:
105
105
return color
106
106
107
107
108
- class TestServer :
109
- server : ThriftServer
110
- serve_task : asyncio .Task
108
+ def local_server (handler : ServiceInterface | None = None ) -> TestServer :
109
+ if handler is None :
110
+ handler = Handler ()
111
+ return TestServer (handler = handler , ip = "::1" )
111
112
112
- def __init__ (
113
- self ,
114
- ip : Optional [str ] = None ,
115
- path : Optional ["Path" ] = None ,
116
- handler : ServiceInterface = Handler (), # noqa: B008
117
- ) -> None :
118
- self .server = ThriftServer (handler , ip = ip , path = path )
119
- # pyre-fixme[8]: The initialization below eliminates
120
- # the pyre[13] error, but results in
121
- # pyre[4] and pyre[8] errors.
122
- # __aenter__ sets the required value.
123
- self .serve_task = None
124
113
125
- async def __aenter__ (self ) -> SocketAddress :
126
- self .serve_task = asyncio .get_event_loop ().create_task (self .server .serve ())
127
- return await self .server .get_address ()
128
-
129
- # pyre-fixme[2]: Parameter must be annotated.
130
- async def __aexit__ (self , * exc_info ) -> None :
131
- self .server .stop ()
132
- await self .serve_task
114
+ def default_server () -> TestServer :
115
+ # note in this case, port is set to 0
116
+ return TestServer (handler = Handler ())
133
117
134
118
135
119
class ClientServerTests (unittest .IsolatedAsyncioTestCase ):
@@ -138,7 +122,7 @@ class ClientServerTests(unittest.IsolatedAsyncioTestCase):
138
122
"""
139
123
140
124
async def test_get_context (self ) -> None :
141
- async with TestServer ( ip = "::1" ) as sa :
125
+ async with local_server ( ) as sa :
142
126
ip , port = sa .ip , sa .port
143
127
assert ip and port
144
128
async with get_client (TestingService , host = ip , port = port ) as client :
@@ -166,7 +150,7 @@ async def test_get_context(self) -> None:
166
150
await handler .getName ()
167
151
168
152
async def test_rpc_headers (self ) -> None :
169
- async with TestServer ( ip = "::1" ) as sa :
153
+ async with local_server ( ) as sa :
170
154
ip , port = sa .ip , sa .port
171
155
assert ip and port
172
156
async with get_client (TestingService , host = ip , port = port ) as client :
@@ -176,7 +160,7 @@ async def test_rpc_headers(self) -> None:
176
160
self .assertIn ("from server" , options .read_headers )
177
161
178
162
async def test_server_localhost (self ) -> None :
179
- async with TestServer ( ip = "::1" ) as sa :
163
+ async with local_server ( ) as sa :
180
164
ip , port = sa .ip , sa .port
181
165
assert ip and port
182
166
async with get_client (TestingService , host = ip , port = port ) as client :
@@ -187,7 +171,7 @@ async def test_server_localhost(self) -> None:
187
171
await client .takes_a_list ([])
188
172
189
173
async def test_no_client_aexit (self ) -> None :
190
- async with TestServer () as sa :
174
+ async with default_server () as sa :
191
175
ip , port = sa .ip , sa .port
192
176
assert ip and port
193
177
client = get_client (TestingService , host = ip , port = port )
@@ -202,7 +186,7 @@ async def test_client_aexit_no_await(self) -> None:
202
186
This actually handles the case if __aexit__ is not awaited
203
187
"""
204
188
205
- async with TestServer () as sa :
189
+ async with default_server () as sa :
206
190
ip , port = sa .ip , sa .port
207
191
assert ip and port
208
192
client = get_client (TestingService , host = ip , port = port )
@@ -218,7 +202,7 @@ async def test_no_client_no_aenter(self) -> None:
218
202
This covers if aenter was canceled since those two are the same really
219
203
"""
220
204
221
- async with TestServer () as sa :
205
+ async with default_server () as sa :
222
206
ip , port = sa .ip , sa .port
223
207
assert ip and port
224
208
get_client (TestingService , host = ip , port = port )
@@ -230,7 +214,7 @@ async def test_derived_service(self) -> None:
230
214
This tests calling methods from a derived service
231
215
"""
232
216
233
- async with TestServer (handler = DerivedHandler ()) as sa :
217
+ async with local_server (handler = DerivedHandler ()) as sa :
234
218
ip , port = sa .ip , sa .port
235
219
assert ip and port
236
220
async with get_client (
@@ -244,7 +228,7 @@ async def test_derived_service(self) -> None:
244
228
)
245
229
246
230
async def test_renamed_func (self ) -> None :
247
- async with TestServer ( ip = "::1" ) as sa :
231
+ async with local_server ( ) as sa :
248
232
ip , port = sa .ip , sa .port
249
233
assert ip and port
250
234
async with get_client (TestingService , host = ip , port = port ) as client :
@@ -303,7 +287,7 @@ async def getName(self) -> str:
303
287
cancelledMessage
304
288
) # Pretend that this is some await call that gets cancelled
305
289
306
- async with TestServer (handler = CancelHandler (), ip = "::1" ) as sa :
290
+ async with local_server (handler = CancelHandler ()) as sa :
307
291
ip , port = sa .ip , sa .port
308
292
assert ip and port
309
293
async with get_client (TestingService , host = ip , port = port ) as client :
@@ -324,7 +308,7 @@ class ErrorHandler(TestingServiceInterface):
324
308
async def getName (self ) -> str :
325
309
raise Exception (errMessage )
326
310
327
- async with TestServer (handler = ErrorHandler (), ip = "::1" ) as sa :
311
+ async with local_server (handler = ErrorHandler ()) as sa :
328
312
ip , port = sa .ip , sa .port
329
313
assert ip and port
330
314
async with get_client (TestingService , host = ip , port = port ) as client :
@@ -336,7 +320,7 @@ async def getName(self) -> str:
336
320
)
337
321
338
322
async def test_request_with_default_rpc_options (self ) -> None :
339
- async with TestServer ( ip = "::1" ) as sa :
323
+ async with local_server ( ) as sa :
340
324
ip , port = sa .ip , sa .port
341
325
assert ip and port
342
326
async with get_client (TestingService , host = ip , port = port ) as client :
@@ -346,7 +330,7 @@ async def test_request_with_default_rpc_options(self) -> None:
346
330
self .assertEqual (Priority (priority ), Priority .N_PRIORITIES )
347
331
348
332
async def test_request_with_specified_rpc_options (self ) -> None :
349
- async with TestServer ( ip = "::1" ) as sa :
333
+ async with local_server ( ) as sa :
350
334
ip , port = sa .ip , sa .port
351
335
assert ip and port
352
336
async with get_client (TestingService , host = ip , port = port ) as client :
@@ -360,7 +344,7 @@ async def test_request_with_specified_rpc_options(self) -> None:
360
344
361
345
async def test_client_event_handler_throw (self ) -> None :
362
346
for handler in ThrowHelperHandler :
363
- async with TestServer ( ip = "::1" ) as sa :
347
+ async with local_server ( ) as sa :
364
348
ip , port = sa .ip , sa .port
365
349
self .assertIsNotNone (ip )
366
350
self .assertIsNotNone (port )
@@ -406,7 +390,7 @@ class ClientStackServerTests(unittest.IsolatedAsyncioTestCase):
406
390
"""
407
391
408
392
async def test_server_localhost (self ) -> None :
409
- async with TestServer (handler = StackHandler (), ip = "::1" ) as sa :
393
+ async with local_server (handler = StackHandler ()) as sa :
410
394
ip , port = sa .ip , sa .port
411
395
assert ip and port
412
396
async with get_client (StackService , host = ip , port = port ) as client :
0 commit comments