Skip to content

Commit d2286ff

Browse files
committed
Merge pull request #150 from flying-sheep/patch-1
Add websocket subprotocol support
2 parents e494364 + 92ce18b commit d2286ff

File tree

3 files changed

+76
-14
lines changed

3 files changed

+76
-14
lines changed

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ Sebastian Hanula
2222
Simon Kennedy
2323
Vaibhav Sagar
2424
Vitaly Haritonsky
25+
Philipp A.

aiohttp/websocket.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
2323
WS_HDRS = ('UPGRADE', 'CONNECTION',
24-
'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY')
24+
'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY', 'SEC-WEBSOCKET-PROTOCOL')
2525

2626
Message = collections.namedtuple('Message', ['tp', 'data', 'extra'])
2727

@@ -182,10 +182,14 @@ def close(self, code=1000, message=b''):
182182
opcode=OPCODE_CLOSE)
183183

184184

185-
def do_handshake(method, headers, transport):
185+
def do_handshake(method, headers, transport, protocols=()):
186186
"""Prepare WebSocket handshake. It return http response code,
187187
response headers, websocket parser, websocket writer. It does not
188-
perform any IO."""
188+
perform any IO.
189+
190+
`protocols` is a sequence of known protocols. On successful handshake,
191+
the returned response headers contain the first protocol in this list
192+
which the server also knows."""
189193

190194
# WebSocket accepts only GET
191195
if method.upper() != 'GET':
@@ -201,6 +205,21 @@ def do_handshake(method, headers, transport):
201205
raise errors.HttpBadRequest(
202206
'No CONNECTION upgrade hdr: {}'.format(
203207
headers.get('CONNECTION')))
208+
209+
# find common sub-protocol between client and server
210+
protocol = None
211+
if 'SEC-WEBSOCKET-PROTOCOL' in headers:
212+
req_protocols = {str(proto.strip()) for proto in
213+
headers['SEC-WEBSOCKET-PROTOCOL'].split(',')}
214+
215+
for proto in protocols:
216+
if proto in req_protocols:
217+
protocol = proto
218+
break
219+
else:
220+
raise errors.HttpBadRequest(
221+
'Client protocols {!r} don’t overlap server-known ones {!r}'
222+
.format(protocols, req_protocols))
204223

205224
# check supported version
206225
version = headers.get('SEC-WEBSOCKET-VERSION')
@@ -218,12 +237,18 @@ def do_handshake(method, headers, transport):
218237
raise errors.HttpBadRequest(
219238
'Handshake error: {!r}'.format(key)) from None
220239

221-
# response code, headers, parser, writer
222-
return (101,
223-
(('UPGRADE', 'websocket'),
240+
response_headers = [('UPGRADE', 'websocket'),
224241
('CONNECTION', 'upgrade'),
225242
('TRANSFER-ENCODING', 'chunked'),
226243
('SEC-WEBSOCKET-ACCEPT', base64.b64encode(
227-
hashlib.sha1(key.encode() + WS_KEY).digest()).decode())),
244+
hashlib.sha1(key.encode() + WS_KEY).digest()).decode())]
245+
246+
if protocol:
247+
response_headers.append(('SEC-WEBSOCKET-PROTOCOL', protocol))
248+
249+
# response code, headers, parser, writer, protocol
250+
return (101,
251+
response_headers,
228252
WebSocketParser,
229-
WebSocketWriter(transport))
253+
WebSocketWriter(transport),
254+
protocol)

tests/test_websocket.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,18 +419,54 @@ def test_protocol_key(self):
419419
websocket.do_handshake,
420420
self.message.method, self.message.headers, self.transport)
421421

422+
def gen_ws_headers(self, protocols=''):
423+
key = base64.b64encode(os.urandom(16)).decode()
424+
hdrs = [('UPGRADE', 'websocket'),
425+
('CONNECTION', 'upgrade'),
426+
('SEC-WEBSOCKET-VERSION', '13'),
427+
('SEC-WEBSOCKET-KEY', key)]
428+
if protocols:
429+
hdrs += [('SEC-WEBSOCKET-PROTOCOL', protocols)]
430+
return hdrs, key
431+
422432
def test_handshake(self):
423-
sec_key = base64.b64encode(os.urandom(16)).decode()
433+
hdrs, sec_key = self.gen_ws_headers()
424434

425-
self.headers.extend([('UPGRADE', 'websocket'),
426-
('CONNECTION', 'upgrade'),
427-
('SEC-WEBSOCKET-VERSION', '13'),
428-
('SEC-WEBSOCKET-KEY', sec_key)])
429-
status, headers, parser, writer = websocket.do_handshake(
435+
self.headers.extend(hdrs)
436+
status, headers, parser, writer, protocol = websocket.do_handshake(
430437
self.message.method, self.message.headers, self.transport)
431438
self.assertEqual(status, 101)
439+
self.assertIsNone(protocol)
432440

433441
key = base64.b64encode(
434442
hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest())
435443
headers = dict(headers)
436444
self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode())
445+
446+
def test_handshake_protocol(self):
447+
'''Tests if one protocol is returned by do_handshake'''
448+
proto = 'chat'
449+
450+
self.headers.extend(self.gen_ws_headers(proto)[0])
451+
_, resp_headers, _, _, protocol = websocket.do_handshake(
452+
self.message.method, self.message.headers, self.transport,
453+
protocols=[proto])
454+
455+
self.assertEqual(protocol, proto)
456+
457+
#also test if we reply with the protocol
458+
resp_headers = dict(resp_headers)
459+
self.assertEqual(resp_headers['SEC-WEBSOCKET-PROTOCOL'], proto)
460+
461+
def test_handshake_protocol_agreement(self):
462+
'''Tests if the right protocol is selected given multiple'''
463+
best_proto = 'chat'
464+
wanted_protos = ['best', 'chat', 'worse_proto']
465+
server_protos = 'worse_proto,chat'
466+
467+
self.headers.extend(self.gen_ws_headers(server_protos)[0])
468+
_, resp_headers, _, _, protocol = websocket.do_handshake(
469+
self.message.method, self.message.headers, self.transport,
470+
protocols=wanted_protos)
471+
472+
self.assertEqual(protocol, best_proto)

0 commit comments

Comments
 (0)