@@ -70,31 +70,35 @@ def new_loop(self):
70
70
71
71
72
72
###############################################################################
73
- ## Socket Testing Utilities
73
+ # Socket Testing Utilities
74
74
###############################################################################
75
75
76
76
77
- def unix_server (server_prog , * ,
78
- addr = None ,
79
- timeout = 1 ,
80
- backlog = 1 ,
81
- max_clients = 1 ):
77
+ def tcp_server (server_prog , * ,
78
+ family = socket .AF_INET ,
79
+ addr = None ,
80
+ timeout = 5 ,
81
+ backlog = 1 ,
82
+ max_clients = 10 ):
83
+
84
+ if addr is None :
85
+ if family == socket .AF_UNIX :
86
+ with tempfile .NamedTemporaryFile () as tmp :
87
+ addr = tmp .name
88
+ else :
89
+ addr = ('127.0.0.1' , 0 )
82
90
83
91
if not inspect .isgeneratorfunction (server_prog ):
84
92
raise TypeError ('server_prog: a generator function was expected' )
85
93
86
- sock = socket .socket (socket . AF_UNIX , socket .SOCK_STREAM )
94
+ sock = socket .socket (family , socket .SOCK_STREAM )
87
95
88
96
if timeout is None :
89
97
raise RuntimeError ('timeout is required' )
90
98
if timeout <= 0 :
91
99
raise RuntimeError ('only blocking sockets are supported' )
92
100
sock .settimeout (timeout )
93
101
94
- if addr is None :
95
- with tempfile .NamedTemporaryFile () as tmp :
96
- addr = tmp .name
97
-
98
102
try :
99
103
sock .bind (addr )
100
104
sock .listen (backlog )
@@ -106,15 +110,12 @@ def unix_server(server_prog, *,
106
110
return srv
107
111
108
112
109
- def tcp_server ( server_prog , * ,
113
+ def tcp_client ( client_prog ,
110
114
family = socket .AF_INET ,
111
- addr = ('127.0.0.1' , 0 ),
112
- timeout = 1 ,
113
- backlog = 1 ,
114
- max_clients = 1 ):
115
+ timeout = 10 ):
115
116
116
- if not inspect .isgeneratorfunction (server_prog ):
117
- raise TypeError ('server_prog : a generator function was expected' )
117
+ if not inspect .isgeneratorfunction (client_prog ):
118
+ raise TypeError ('client_prog : a generator function was expected' )
118
119
119
120
sock = socket .socket (family , socket .SOCK_STREAM )
120
121
@@ -124,18 +125,59 @@ def tcp_server(server_prog, *,
124
125
raise RuntimeError ('only blocking sockets are supported' )
125
126
sock .settimeout (timeout )
126
127
127
- try :
128
- sock .bind (addr )
129
- sock .listen (backlog )
130
- except OSError as ex :
131
- sock .close ()
132
- raise ex
133
-
134
- srv = Server (sock , server_prog , timeout , max_clients )
128
+ srv = Client (sock , client_prog , timeout )
135
129
return srv
136
130
137
131
138
- class Server (threading .Thread ):
132
+ class _Runner :
133
+ def _iterate (self , prog , sock ):
134
+ last_val = None
135
+ while self ._active :
136
+ try :
137
+ command = prog .send (last_val )
138
+ except StopIteration :
139
+ return
140
+
141
+ if not isinstance (command , _Command ):
142
+ raise TypeError (
143
+ 'client_prog yielded invalid command {!r}' .format (command ))
144
+
145
+ command_res = command ._run (sock )
146
+ assert isinstance (command_res , tuple ) and len (command_res ) == 2
147
+
148
+ last_val = command_res [1 ]
149
+ sock = command_res [0 ]
150
+
151
+ def stop (self ):
152
+ self ._active = False
153
+ self .join ()
154
+
155
+ def __enter__ (self ):
156
+ self .start ()
157
+ return self
158
+
159
+ def __exit__ (self , * exc ):
160
+ self .stop ()
161
+
162
+
163
+ class Client (_Runner , threading .Thread ):
164
+
165
+ def __init__ (self , sock , prog , timeout ):
166
+ threading .Thread .__init__ (self , None , None , 'test-client' )
167
+ self .daemon = True
168
+
169
+ self ._timeout = timeout
170
+ self ._sock = sock
171
+ self ._active = True
172
+ self ._prog = prog
173
+
174
+ def run (self ):
175
+ prog = self ._prog ()
176
+ sock = self ._sock
177
+ self ._iterate (prog , sock )
178
+
179
+
180
+ class Server (_Runner , threading .Thread ):
139
181
140
182
def __init__ (self , sock , prog , timeout , max_clients ):
141
183
threading .Thread .__init__ (self , None , None , 'test-server' )
@@ -173,53 +215,20 @@ def run(self):
173
215
174
216
def _handle_client (self , sock ):
175
217
prog = self ._prog ()
176
-
177
- last_val = None
178
- while self ._active :
179
- try :
180
- command = prog .send (last_val )
181
- except StopIteration :
182
- self ._finished_clients += 1
183
- return
184
-
185
- if not isinstance (command , Command ):
186
- raise TypeError (
187
- 'server_prog yielded invalid command {!r}' .format (command ))
188
-
189
- command_res = command ._run (sock )
190
- assert isinstance (command_res , tuple ) and len (command_res ) == 2
191
-
192
- last_val = command_res [1 ]
193
- sock = command_res [0 ]
218
+ self ._iterate (prog , sock )
194
219
195
220
@property
196
221
def addr (self ):
197
222
return self ._sock .getsockname ()
198
223
199
- def stop (self ):
200
- self ._active = False
201
- self .join ()
202
-
203
- if self ._finished_clients != self ._clients :
204
- raise AssertionError (
205
- 'not all clients are finished: {!r}' .format (
206
- self ._clients - self ._finished_clients ))
207
-
208
- def __enter__ (self ):
209
- self .start ()
210
- return self
211
-
212
- def __exit__ (self , * exc ):
213
- self .stop ()
214
-
215
224
216
- class Command :
225
+ class _Command :
217
226
218
227
def _run (self , sock ):
219
228
raise NotImplementedError
220
229
221
230
222
- class write (Command ):
231
+ class write (_Command ):
223
232
224
233
def __init__ (self , data :bytes ):
225
234
self ._data = data
@@ -229,13 +238,22 @@ def _run(self, sock):
229
238
return sock , None
230
239
231
240
232
- class close (Command ):
241
+ class connect (_Command ):
242
+ def __init__ (self , addr ):
243
+ self ._addr = addr
244
+
245
+ def _run (self , sock ):
246
+ sock .connect (self ._addr )
247
+ return sock , None
248
+
249
+
250
+ class close (_Command ):
233
251
def _run (self , sock ):
234
252
sock .close ()
235
253
return sock , None
236
254
237
255
238
- class read (Command ):
256
+ class read (_Command ):
239
257
240
258
def __init__ (self , nbytes ):
241
259
self ._nbytes = nbytes
@@ -260,23 +278,27 @@ def _run(self, sock):
260
278
return sock , data
261
279
262
280
263
- class starttls (Command ):
281
+ class starttls (_Command ):
264
282
265
283
def __init__ (self , ssl_context , * ,
266
284
server_side = False ,
267
- server_hostname = None ):
285
+ server_hostname = None ,
286
+ do_handshake_on_connect = True ):
268
287
269
288
assert isinstance (ssl_context , ssl .SSLContext )
270
289
self ._ctx = ssl_context
271
290
272
291
self ._server_side = server_side
273
292
self ._server_hostname = server_hostname
293
+ self ._do_handshake_on_connect = do_handshake_on_connect
274
294
275
295
def _run (self , sock ):
276
296
ssl_sock = self ._ctx .wrap_socket (
277
297
sock , server_side = self ._server_side ,
278
- server_hostname = self ._server_hostname )
298
+ server_hostname = self ._server_hostname ,
299
+ do_handshake_on_connect = self ._do_handshake_on_connect )
279
300
280
- ssl_sock .do_handshake ()
301
+ if self ._server_side :
302
+ ssl_sock .do_handshake ()
281
303
282
304
return ssl_sock , None
0 commit comments