Skip to content

Commit 92ce18b

Browse files
committed
broke websocket API and added tests
1 parent a5b200e commit 92ce18b

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

aiohttp/websocket.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def do_handshake(method, headers, transport, protocols=()):
218218
break
219219
else:
220220
raise errors.HttpBadRequest(
221-
'Client protocols {} don’t overlap server-known ones {}'
221+
'Client protocols {!r} don’t overlap server-known ones {!r}'
222222
.format(protocols, req_protocols))
223223

224224
# check supported version
@@ -246,8 +246,9 @@ def do_handshake(method, headers, transport, protocols=()):
246246
if protocol:
247247
response_headers.append(('SEC-WEBSOCKET-PROTOCOL', protocol))
248248

249-
# response code, headers, parser, writer
249+
# response code, headers, parser, writer, protocol
250250
return (101,
251251
response_headers,
252252
WebSocketParser,
253-
WebSocketWriter(transport))
253+
WebSocketWriter(transport),
254+
protocol)

tests/test_websocket.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,18 +419,54 @@ def test_protocol_key(self):
419419
websocket.do_handshake,
420420
self.message.method, self.message.headers, self.transport)
421421

422+
def gen_ws_headers(self, protocols=''):
423+
key = base64.b64encode(os.urandom(16)).decode()
424+
hdrs = [('UPGRADE', 'websocket'),
425+
('CONNECTION', 'upgrade'),
426+
('SEC-WEBSOCKET-VERSION', '13'),
427+
('SEC-WEBSOCKET-KEY', key)]
428+
if protocols:
429+
hdrs += [('SEC-WEBSOCKET-PROTOCOL', protocols)]
430+
return hdrs, key
431+
422432
def test_handshake(self):
423-
sec_key = base64.b64encode(os.urandom(16)).decode()
433+
hdrs, sec_key = self.gen_ws_headers()
424434

425-
self.headers.extend([('UPGRADE', 'websocket'),
426-
('CONNECTION', 'upgrade'),
427-
('SEC-WEBSOCKET-VERSION', '13'),
428-
('SEC-WEBSOCKET-KEY', sec_key)])
429-
status, headers, parser, writer = websocket.do_handshake(
435+
self.headers.extend(hdrs)
436+
status, headers, parser, writer, protocol = websocket.do_handshake(
430437
self.message.method, self.message.headers, self.transport)
431438
self.assertEqual(status, 101)
439+
self.assertIsNone(protocol)
432440

433441
key = base64.b64encode(
434442
hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest())
435443
headers = dict(headers)
436444
self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode())
445+
446+
def test_handshake_protocol(self):
447+
'''Tests if one protocol is returned by do_handshake'''
448+
proto = 'chat'
449+
450+
self.headers.extend(self.gen_ws_headers(proto)[0])
451+
_, resp_headers, _, _, protocol = websocket.do_handshake(
452+
self.message.method, self.message.headers, self.transport,
453+
protocols=[proto])
454+
455+
self.assertEqual(protocol, proto)
456+
457+
#also test if we reply with the protocol
458+
resp_headers = dict(resp_headers)
459+
self.assertEqual(resp_headers['SEC-WEBSOCKET-PROTOCOL'], proto)
460+
461+
def test_handshake_protocol_agreement(self):
462+
'''Tests if the right protocol is selected given multiple'''
463+
best_proto = 'chat'
464+
wanted_protos = ['best', 'chat', 'worse_proto']
465+
server_protos = 'worse_proto,chat'
466+
467+
self.headers.extend(self.gen_ws_headers(server_protos)[0])
468+
_, resp_headers, _, _, protocol = websocket.do_handshake(
469+
self.message.method, self.message.headers, self.transport,
470+
protocols=wanted_protos)
471+
472+
self.assertEqual(protocol, best_proto)

0 commit comments

Comments
 (0)