Skip to content

Commit ab67908

Browse files
ahilgerfacebook-github-bot
authored andcommitted
consolidate TestServer into one
Summary: This concept is copy-pasted at least 3 places, soon to be 4, so let's consolidate into a proper helper lib. Reviewed By: prakashgayasen Differential Revision: D79925577 fbshipit-source-id: f8524634179b0afefdfa6439b1b78fd4810923c3
1 parent 604b8be commit ab67908

File tree

5 files changed

+98
-122
lines changed

5 files changed

+98
-122
lines changed

third-party/thrift/src/thrift/lib/python/test/binary.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from __future__ import annotations
2020

21-
import asyncio
2221
import unittest
2322
from typing import Any, cast, Type
2423

@@ -41,9 +40,9 @@
4140
from folly.iobuf import IOBuf
4241

4342
from parameterized import parameterized
44-
from thrift.py3.server import SocketAddress
43+
from thrift.lib.python.test.test_server import TestServer
4544
from thrift.python.client import get_client
46-
from thrift.python.server import ServiceInterface, ThriftServer
45+
from thrift.python.server import ServiceInterface
4746

4847

4948
class BinaryTests(unittest.TestCase):
@@ -73,6 +72,10 @@ def test_binary_union(
7372
self.assertEqual(bytes(val.iobuf_val), b"mnopqr")
7473

7574

75+
def local_server(handler: ServiceInterface) -> TestServer:
76+
return TestServer(handler=handler, ip="::1")
77+
78+
7679
def get_binary_handler_type(
7780
BinaryServiceInterface: Type[BinaryServiceInterfaceImmutable]
7881
| Type[BinaryServiceInterfaceMutable],
@@ -125,24 +128,6 @@ async def sendRecBinaryUnion(self, val: BinaryUnion) -> BinaryUnion:
125128
return cast(Type[ServiceInterface], BinaryHandler)
126129

127130

128-
class TestServer:
129-
server: ThriftServer
130-
# pyre-fixme[13]: Attribute `serve_task` is never initialized.
131-
serve_task: asyncio.Task
132-
133-
def __init__(self, *, ip: str, handler: ServiceInterface) -> None:
134-
self.server = ThriftServer(handler, ip=ip, path=None)
135-
136-
async def __aenter__(self) -> SocketAddress:
137-
self.serve_task = asyncio.get_event_loop().create_task(self.server.serve())
138-
return await self.server.get_address()
139-
140-
# pyre-fixme[2]: Parameter must be annotated.
141-
async def __aexit__(self, *exc_info) -> None:
142-
self.server.stop()
143-
await self.serve_task
144-
145-
146131
class ClientBinaryServerTests(unittest.IsolatedAsyncioTestCase):
147132
@parameterized.expand(
148133
[
@@ -172,7 +157,7 @@ async def test_send_recv(
172157
BinaryServiceInterface, Binaries, BinaryUnion
173158
)
174159
# pyre-ignore[19]: `object.__init__` expects 0 positional arguments
175-
async with TestServer(handler=BinaryHandler(self), ip="::1") as sa:
160+
async with local_server(handler=BinaryHandler(self)) as sa:
176161
ip, port = sa.ip, sa.port
177162
assert ip and port
178163
async with get_client(BinaryService, host=ip, port=port) as client:

third-party/thrift/src/thrift/lib/python/test/client_server.py

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
import asyncio
2222
import time
2323
import unittest
24-
from pathlib import Path
25-
from typing import Optional, Sequence
24+
from typing import Sequence
2625

2726
from derived.thrift_clients import DerivedTestingService
2827
from derived.thrift_services import DerivedTestingServiceInterface
@@ -35,11 +34,12 @@
3534
from testing.thrift_services import TestingServiceInterface
3635
from testing.thrift_types import Color, easy, SimpleError
3736
from thrift.lib.python.test.event_handlers.helper import ThrowHelper, ThrowHelperHandler
37+
from thrift.lib.python.test.test_server import TestServer
3838
from thrift.py3.server import get_context, SocketAddress
3939
from thrift.python.client import get_client
4040
from thrift.python.common import Priority, RpcOptions
4141
from thrift.python.exceptions import ApplicationError
42-
from thrift.python.server import ServiceInterface, ThriftServer
42+
from thrift.python.server import ServiceInterface
4343

4444

4545
class Handler(TestingServiceInterface):
@@ -105,31 +105,15 @@ async def derived_pick_a_color(self, color: Color) -> Color:
105105
return color
106106

107107

108-
class TestServer:
109-
server: ThriftServer
110-
serve_task: asyncio.Task
108+
def local_server(handler: ServiceInterface | None = None) -> TestServer:
109+
if handler is None:
110+
handler = Handler()
111+
return TestServer(handler=handler, ip="::1")
111112

112-
def __init__(
113-
self,
114-
ip: Optional[str] = None,
115-
path: Optional["Path"] = None,
116-
handler: ServiceInterface = Handler(), # noqa: B008
117-
) -> None:
118-
self.server = ThriftServer(handler, ip=ip, path=path)
119-
# pyre-fixme[8]: The initialization below eliminates
120-
# the pyre[13] error, but results in
121-
# pyre[4] and pyre[8] errors.
122-
# __aenter__ sets the required value.
123-
self.serve_task = None
124113

125-
async def __aenter__(self) -> SocketAddress:
126-
self.serve_task = asyncio.get_event_loop().create_task(self.server.serve())
127-
return await self.server.get_address()
128-
129-
# pyre-fixme[2]: Parameter must be annotated.
130-
async def __aexit__(self, *exc_info) -> None:
131-
self.server.stop()
132-
await self.serve_task
114+
def default_server() -> TestServer:
115+
# note in this case, port is set to 0
116+
return TestServer(handler=Handler())
133117

134118

135119
class ClientServerTests(unittest.IsolatedAsyncioTestCase):
@@ -138,7 +122,7 @@ class ClientServerTests(unittest.IsolatedAsyncioTestCase):
138122
"""
139123

140124
async def test_get_context(self) -> None:
141-
async with TestServer(ip="::1") as sa:
125+
async with local_server() as sa:
142126
ip, port = sa.ip, sa.port
143127
assert ip and port
144128
async with get_client(TestingService, host=ip, port=port) as client:
@@ -166,7 +150,7 @@ async def test_get_context(self) -> None:
166150
await handler.getName()
167151

168152
async def test_rpc_headers(self) -> None:
169-
async with TestServer(ip="::1") as sa:
153+
async with local_server() as sa:
170154
ip, port = sa.ip, sa.port
171155
assert ip and port
172156
async with get_client(TestingService, host=ip, port=port) as client:
@@ -176,7 +160,7 @@ async def test_rpc_headers(self) -> None:
176160
self.assertIn("from server", options.read_headers)
177161

178162
async def test_server_localhost(self) -> None:
179-
async with TestServer(ip="::1") as sa:
163+
async with local_server() as sa:
180164
ip, port = sa.ip, sa.port
181165
assert ip and port
182166
async with get_client(TestingService, host=ip, port=port) as client:
@@ -187,7 +171,7 @@ async def test_server_localhost(self) -> None:
187171
await client.takes_a_list([])
188172

189173
async def test_no_client_aexit(self) -> None:
190-
async with TestServer() as sa:
174+
async with default_server() as sa:
191175
ip, port = sa.ip, sa.port
192176
assert ip and port
193177
client = get_client(TestingService, host=ip, port=port)
@@ -202,7 +186,7 @@ async def test_client_aexit_no_await(self) -> None:
202186
This actually handles the case if __aexit__ is not awaited
203187
"""
204188

205-
async with TestServer() as sa:
189+
async with default_server() as sa:
206190
ip, port = sa.ip, sa.port
207191
assert ip and port
208192
client = get_client(TestingService, host=ip, port=port)
@@ -218,7 +202,7 @@ async def test_no_client_no_aenter(self) -> None:
218202
This covers if aenter was canceled since those two are the same really
219203
"""
220204

221-
async with TestServer() as sa:
205+
async with default_server() as sa:
222206
ip, port = sa.ip, sa.port
223207
assert ip and port
224208
get_client(TestingService, host=ip, port=port)
@@ -230,7 +214,7 @@ async def test_derived_service(self) -> None:
230214
This tests calling methods from a derived service
231215
"""
232216

233-
async with TestServer(handler=DerivedHandler()) as sa:
217+
async with local_server(handler=DerivedHandler()) as sa:
234218
ip, port = sa.ip, sa.port
235219
assert ip and port
236220
async with get_client(
@@ -244,7 +228,7 @@ async def test_derived_service(self) -> None:
244228
)
245229

246230
async def test_renamed_func(self) -> None:
247-
async with TestServer(ip="::1") as sa:
231+
async with local_server() as sa:
248232
ip, port = sa.ip, sa.port
249233
assert ip and port
250234
async with get_client(TestingService, host=ip, port=port) as client:
@@ -303,7 +287,7 @@ async def getName(self) -> str:
303287
cancelledMessage
304288
) # Pretend that this is some await call that gets cancelled
305289

306-
async with TestServer(handler=CancelHandler(), ip="::1") as sa:
290+
async with local_server(handler=CancelHandler()) as sa:
307291
ip, port = sa.ip, sa.port
308292
assert ip and port
309293
async with get_client(TestingService, host=ip, port=port) as client:
@@ -324,7 +308,7 @@ class ErrorHandler(TestingServiceInterface):
324308
async def getName(self) -> str:
325309
raise Exception(errMessage)
326310

327-
async with TestServer(handler=ErrorHandler(), ip="::1") as sa:
311+
async with local_server(handler=ErrorHandler()) as sa:
328312
ip, port = sa.ip, sa.port
329313
assert ip and port
330314
async with get_client(TestingService, host=ip, port=port) as client:
@@ -336,7 +320,7 @@ async def getName(self) -> str:
336320
)
337321

338322
async def test_request_with_default_rpc_options(self) -> None:
339-
async with TestServer(ip="::1") as sa:
323+
async with local_server() as sa:
340324
ip, port = sa.ip, sa.port
341325
assert ip and port
342326
async with get_client(TestingService, host=ip, port=port) as client:
@@ -346,7 +330,7 @@ async def test_request_with_default_rpc_options(self) -> None:
346330
self.assertEqual(Priority(priority), Priority.N_PRIORITIES)
347331

348332
async def test_request_with_specified_rpc_options(self) -> None:
349-
async with TestServer(ip="::1") as sa:
333+
async with local_server() as sa:
350334
ip, port = sa.ip, sa.port
351335
assert ip and port
352336
async with get_client(TestingService, host=ip, port=port) as client:
@@ -360,7 +344,7 @@ async def test_request_with_specified_rpc_options(self) -> None:
360344

361345
async def test_client_event_handler_throw(self) -> None:
362346
for handler in ThrowHelperHandler:
363-
async with TestServer(ip="::1") as sa:
347+
async with local_server() as sa:
364348
ip, port = sa.ip, sa.port
365349
self.assertIsNotNone(ip)
366350
self.assertIsNotNone(port)
@@ -406,7 +390,7 @@ class ClientStackServerTests(unittest.IsolatedAsyncioTestCase):
406390
"""
407391

408392
async def test_server_localhost(self) -> None:
409-
async with TestServer(handler=StackHandler(), ip="::1") as sa:
393+
async with local_server(handler=StackHandler()) as sa:
410394
ip, port = sa.ip, sa.port
411395
assert ip and port
412396
async with get_client(StackService, host=ip, port=port) as client:

third-party/thrift/src/thrift/lib/python/test/metadata_response/metadata_response_test.py

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,23 @@
1818

1919
from __future__ import annotations
2020

21-
import asyncio
2221
import unittest
23-
from pathlib import Path
24-
from typing import Optional, Sequence
22+
from typing import Sequence
2523

2624
from apache.thrift.metadata.thrift_types import ThriftServiceMetadataResponse
2725
from testing.thrift_services import TestingServiceInterface
2826
from testing.thrift_types import Color, easy, SimpleError
29-
from thrift.py3.server import get_context, SocketAddress
27+
from thrift.lib.python.test.test_server import TestServer
28+
from thrift.py3.server import get_context
3029
from thrift.python.serializer import deserialize, Protocol
31-
from thrift.python.server import ServiceInterface, ThriftServer
3230

3331
from .metadata_response import get_serialized_cpp_metadata
3432

3533

34+
def local_server() -> TestServer:
35+
return TestServer(handler=Handler(), ip="::1")
36+
37+
3638
class Handler(TestingServiceInterface):
3739
async def invert(self, value: bool) -> bool:
3840
ctx = get_context()
@@ -84,36 +86,13 @@ async def renamed_func(self, ret: bool) -> bool:
8486
return ret
8587

8688

87-
class TestServer:
88-
server: ThriftServer
89-
# pyre-fixme[13]: Attribute `serve_task` is never initialized.
90-
serve_task: asyncio.Task
91-
92-
def __init__(
93-
self,
94-
ip: Optional[str] = None,
95-
path: Optional["Path"] = None,
96-
handler: ServiceInterface = Handler(), # noqa: B008
97-
) -> None:
98-
self.server = ThriftServer(handler, ip=ip, path=path)
99-
100-
async def __aenter__(self) -> SocketAddress:
101-
self.serve_task = asyncio.get_event_loop().create_task(self.server.serve())
102-
return await self.server.get_address()
103-
104-
# pyre-fixme[2]: Parameter must be annotated.
105-
async def __aexit__(self, *exc_info) -> None:
106-
self.server.stop()
107-
await self.serve_task
108-
109-
11089
class MetadataResponseTest(unittest.IsolatedAsyncioTestCase):
11190
"""
11291
These are tests where a client and server talk to each other
11392
"""
11493

11594
async def test_server_localhost(self) -> None:
116-
server = TestServer(ip="::1")
95+
server = local_server()
11796
async with server as _:
11897
metadata_cpp = deserialize(
11998
ThriftServiceMetadataResponse,

0 commit comments

Comments
 (0)