@@ -419,18 +419,54 @@ def test_protocol_key(self):
419
419
websocket .do_handshake ,
420
420
self .message .method , self .message .headers , self .transport )
421
421
422
+ def gen_ws_headers (self , protocols = '' ):
423
+ key = base64 .b64encode (os .urandom (16 )).decode ()
424
+ hdrs = [('UPGRADE' , 'websocket' ),
425
+ ('CONNECTION' , 'upgrade' ),
426
+ ('SEC-WEBSOCKET-VERSION' , '13' ),
427
+ ('SEC-WEBSOCKET-KEY' , key )]
428
+ if protocols :
429
+ hdrs += [('SEC-WEBSOCKET-PROTOCOL' , protocols )]
430
+ return hdrs , key
431
+
422
432
def test_handshake (self ):
423
- sec_key = base64 . b64encode ( os . urandom ( 16 )). decode ()
433
+ hdrs , sec_key = self . gen_ws_headers ()
424
434
425
- self .headers .extend ([('UPGRADE' , 'websocket' ),
426
- ('CONNECTION' , 'upgrade' ),
427
- ('SEC-WEBSOCKET-VERSION' , '13' ),
428
- ('SEC-WEBSOCKET-KEY' , sec_key )])
429
- status , headers , parser , writer = websocket .do_handshake (
435
+ self .headers .extend (hdrs )
436
+ status , headers , parser , writer , protocol = websocket .do_handshake (
430
437
self .message .method , self .message .headers , self .transport )
431
438
self .assertEqual (status , 101 )
439
+ self .assertIsNone (protocol )
432
440
433
441
key = base64 .b64encode (
434
442
hashlib .sha1 (sec_key .encode () + websocket .WS_KEY ).digest ())
435
443
headers = dict (headers )
436
444
self .assertEqual (headers ['SEC-WEBSOCKET-ACCEPT' ], key .decode ())
445
+
446
+ def test_handshake_protocol (self ):
447
+ '''Tests if one protocol is returned by do_handshake'''
448
+ proto = 'chat'
449
+
450
+ self .headers .extend (self .gen_ws_headers (proto )[0 ])
451
+ _ , resp_headers , _ , _ , protocol = websocket .do_handshake (
452
+ self .message .method , self .message .headers , self .transport ,
453
+ protocols = [proto ])
454
+
455
+ self .assertEqual (protocol , proto )
456
+
457
+ #also test if we reply with the protocol
458
+ resp_headers = dict (resp_headers )
459
+ self .assertEqual (resp_headers ['SEC-WEBSOCKET-PROTOCOL' ], proto )
460
+
461
+ def test_handshake_protocol_agreement (self ):
462
+ '''Tests if the right protocol is selected given multiple'''
463
+ best_proto = 'chat'
464
+ wanted_protos = ['best' , 'chat' , 'worse_proto' ]
465
+ server_protos = 'worse_proto,chat'
466
+
467
+ self .headers .extend (self .gen_ws_headers (server_protos )[0 ])
468
+ _ , resp_headers , _ , _ , protocol = websocket .do_handshake (
469
+ self .message .method , self .message .headers , self .transport ,
470
+ protocols = wanted_protos )
471
+
472
+ self .assertEqual (protocol , best_proto )
0 commit comments