Skip to content

Commit 08246bd

Browse files
committed
testutils: Refactor test harness for sokets; implement tcp_client()
1 parent 11520d5 commit 08246bd

File tree

1 file changed

+91
-69
lines changed

1 file changed

+91
-69
lines changed

uvloop/_testbase.py

Lines changed: 91 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -70,31 +70,35 @@ def new_loop(self):
7070

7171

7272
###############################################################################
73-
## Socket Testing Utilities
73+
# Socket Testing Utilities
7474
###############################################################################
7575

7676

77-
def unix_server(server_prog, *,
78-
addr=None,
79-
timeout=1,
80-
backlog=1,
81-
max_clients=1):
77+
def tcp_server(server_prog, *,
78+
family=socket.AF_INET,
79+
addr=None,
80+
timeout=5,
81+
backlog=1,
82+
max_clients=10):
83+
84+
if addr is None:
85+
if family == socket.AF_UNIX:
86+
with tempfile.NamedTemporaryFile() as tmp:
87+
addr = tmp.name
88+
else:
89+
addr = ('127.0.0.1', 0)
8290

8391
if not inspect.isgeneratorfunction(server_prog):
8492
raise TypeError('server_prog: a generator function was expected')
8593

86-
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
94+
sock = socket.socket(family, socket.SOCK_STREAM)
8795

8896
if timeout is None:
8997
raise RuntimeError('timeout is required')
9098
if timeout <= 0:
9199
raise RuntimeError('only blocking sockets are supported')
92100
sock.settimeout(timeout)
93101

94-
if addr is None:
95-
with tempfile.NamedTemporaryFile() as tmp:
96-
addr = tmp.name
97-
98102
try:
99103
sock.bind(addr)
100104
sock.listen(backlog)
@@ -106,15 +110,12 @@ def unix_server(server_prog, *,
106110
return srv
107111

108112

109-
def tcp_server(server_prog, *,
113+
def tcp_client(client_prog,
110114
family=socket.AF_INET,
111-
addr=('127.0.0.1', 0),
112-
timeout=1,
113-
backlog=1,
114-
max_clients=1):
115+
timeout=10):
115116

116-
if not inspect.isgeneratorfunction(server_prog):
117-
raise TypeError('server_prog: a generator function was expected')
117+
if not inspect.isgeneratorfunction(client_prog):
118+
raise TypeError('client_prog: a generator function was expected')
118119

119120
sock = socket.socket(family, socket.SOCK_STREAM)
120121

@@ -124,18 +125,59 @@ def tcp_server(server_prog, *,
124125
raise RuntimeError('only blocking sockets are supported')
125126
sock.settimeout(timeout)
126127

127-
try:
128-
sock.bind(addr)
129-
sock.listen(backlog)
130-
except OSError as ex:
131-
sock.close()
132-
raise ex
133-
134-
srv = Server(sock, server_prog, timeout, max_clients)
128+
srv = Client(sock, client_prog, timeout)
135129
return srv
136130

137131

138-
class Server(threading.Thread):
132+
class _Runner:
133+
def _iterate(self, prog, sock):
134+
last_val = None
135+
while self._active:
136+
try:
137+
command = prog.send(last_val)
138+
except StopIteration:
139+
return
140+
141+
if not isinstance(command, _Command):
142+
raise TypeError(
143+
'client_prog yielded invalid command {!r}'.format(command))
144+
145+
command_res = command._run(sock)
146+
assert isinstance(command_res, tuple) and len(command_res) == 2
147+
148+
last_val = command_res[1]
149+
sock = command_res[0]
150+
151+
def stop(self):
152+
self._active = False
153+
self.join()
154+
155+
def __enter__(self):
156+
self.start()
157+
return self
158+
159+
def __exit__(self, *exc):
160+
self.stop()
161+
162+
163+
class Client(_Runner, threading.Thread):
164+
165+
def __init__(self, sock, prog, timeout):
166+
threading.Thread.__init__(self, None, None, 'test-client')
167+
self.daemon = True
168+
169+
self._timeout = timeout
170+
self._sock = sock
171+
self._active = True
172+
self._prog = prog
173+
174+
def run(self):
175+
prog = self._prog()
176+
sock = self._sock
177+
self._iterate(prog, sock)
178+
179+
180+
class Server(_Runner, threading.Thread):
139181

140182
def __init__(self, sock, prog, timeout, max_clients):
141183
threading.Thread.__init__(self, None, None, 'test-server')
@@ -173,53 +215,20 @@ def run(self):
173215

174216
def _handle_client(self, sock):
175217
prog = self._prog()
176-
177-
last_val = None
178-
while self._active:
179-
try:
180-
command = prog.send(last_val)
181-
except StopIteration:
182-
self._finished_clients += 1
183-
return
184-
185-
if not isinstance(command, Command):
186-
raise TypeError(
187-
'server_prog yielded invalid command {!r}'.format(command))
188-
189-
command_res = command._run(sock)
190-
assert isinstance(command_res, tuple) and len(command_res) == 2
191-
192-
last_val = command_res[1]
193-
sock = command_res[0]
218+
self._iterate(prog, sock)
194219

195220
@property
196221
def addr(self):
197222
return self._sock.getsockname()
198223

199-
def stop(self):
200-
self._active = False
201-
self.join()
202-
203-
if self._finished_clients != self._clients:
204-
raise AssertionError(
205-
'not all clients are finished: {!r}'.format(
206-
self._clients - self._finished_clients))
207-
208-
def __enter__(self):
209-
self.start()
210-
return self
211-
212-
def __exit__(self, *exc):
213-
self.stop()
214-
215224

216-
class Command:
225+
class _Command:
217226

218227
def _run(self, sock):
219228
raise NotImplementedError
220229

221230

222-
class write(Command):
231+
class write(_Command):
223232

224233
def __init__(self, data:bytes):
225234
self._data = data
@@ -229,13 +238,22 @@ def _run(self, sock):
229238
return sock, None
230239

231240

232-
class close(Command):
241+
class connect(_Command):
242+
def __init__(self, addr):
243+
self._addr = addr
244+
245+
def _run(self, sock):
246+
sock.connect(self._addr)
247+
return sock, None
248+
249+
250+
class close(_Command):
233251
def _run(self, sock):
234252
sock.close()
235253
return sock, None
236254

237255

238-
class read(Command):
256+
class read(_Command):
239257

240258
def __init__(self, nbytes):
241259
self._nbytes = nbytes
@@ -260,23 +278,27 @@ def _run(self, sock):
260278
return sock, data
261279

262280

263-
class starttls(Command):
281+
class starttls(_Command):
264282

265283
def __init__(self, ssl_context, *,
266284
server_side=False,
267-
server_hostname=None):
285+
server_hostname=None,
286+
do_handshake_on_connect=True):
268287

269288
assert isinstance(ssl_context, ssl.SSLContext)
270289
self._ctx = ssl_context
271290

272291
self._server_side = server_side
273292
self._server_hostname = server_hostname
293+
self._do_handshake_on_connect = do_handshake_on_connect
274294

275295
def _run(self, sock):
276296
ssl_sock = self._ctx.wrap_socket(
277297
sock, server_side=self._server_side,
278-
server_hostname=self._server_hostname)
298+
server_hostname=self._server_hostname,
299+
do_handshake_on_connect=self._do_handshake_on_connect)
279300

280-
ssl_sock.do_handshake()
301+
if self._server_side:
302+
ssl_sock.do_handshake()
281303

282304
return ssl_sock, None

0 commit comments

Comments
 (0)