diff --git a/django_websocket/middleware.py b/django_websocket/middleware.py index 85bb779..8cd25f9 100644 --- a/django_websocket/middleware.py +++ b/django_websocket/middleware.py @@ -33,6 +33,6 @@ def process_view(self, request, view_func, view_args, view_kwargs): return HttpResponseBadRequest() def process_response(self, request, response): - if request.is_websocket() and request.websocket._handshake_sent: - request.websocket._send_closing_frame(True) + if request.is_websocket() and request.websocket.handshake_sent: + request.websocket.close() return response diff --git a/django_websocket/websocket.py b/django_websocket/websocket.py index 6a7b1a7..8c419c3 100644 --- a/django_websocket/websocket.py +++ b/django_websocket/websocket.py @@ -1,291 +1,248 @@ -import collections -import select -import string -import struct -try: - from hashlib import md5 -except ImportError: #pragma NO COVER - from md5 import md5 -from errno import EINTR -from socket import error as SocketError +from base64 import b64encode +from hashlib import sha1 +import types + +from ws4py import WS_KEY, WS_VERSION +from ws4py.messaging import Message +from ws4py.streaming import Stream class MalformedWebSocket(ValueError): pass -def _extract_number(value): - """ - Utility function which, given a string like 'g98sd 5[]221@1', will - return 4926105. Used to parse the Sec-WebSocket-Key headers. - - In other words, it extracts digits from a string and returns the number - due to the number of spaces. - """ - out = "" - spaces = 0 - for char in value: - if char in string.digits: - out += char - elif char == " ": - spaces += 1 - return int(out) / spaces - - def setup_websocket(request): - if request.META.get('HTTP_CONNECTION', None) == 'Upgrade' and \ - request.META.get('HTTP_UPGRADE', None) == 'WebSocket': - - # See if they sent the new-format headers - if 'HTTP_SEC_WEBSOCKET_KEY1' in request.META: - protocol_version = 76 - if 'HTTP_SEC_WEBSOCKET_KEY2' not in request.META: - raise MalformedWebSocket() - else: - protocol_version = 75 - - # If it's new-version, we need to work out our challenge response - if protocol_version == 76: - key1 = _extract_number(request.META['HTTP_SEC_WEBSOCKET_KEY1']) - key2 = _extract_number(request.META['HTTP_SEC_WEBSOCKET_KEY2']) - # There's no content-length header in the request, but it has 8 - # bytes of data. - key3 = request.META['wsgi.input'].read(8) - key = struct.pack(">II", key1, key2) + key3 - handshake_response = md5(key).digest() - - location = 'ws://%s%s' % (request.get_host(), request.path) - qs = request.META.get('QUERY_STRING') - if qs: - location += '?' + qs - if protocol_version == 75: - handshake_reply = ( - "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "WebSocket-Origin: %s\r\n" - "WebSocket-Location: %s\r\n\r\n" % ( - request.META.get('HTTP_ORIGIN'), - location)) - elif protocol_version == 76: - handshake_reply = ( - "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Origin: %s\r\n" - "Sec-WebSocket-Protocol: %s\r\n" - "Sec-WebSocket-Location: %s\r\n" % ( - request.META.get('HTTP_ORIGIN'), - request.META.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default'), - location)) - handshake_reply = str(handshake_reply) - handshake_reply = '%s\r\n%s' % (handshake_reply, handshake_response) + if 'Upgrade' in request.META.get('HTTP_CONNECTION', None) and \ + request.META.get('HTTP_UPGRADE', None).lower() == 'websocket': - else: - raise MalformedWebSocket("Unknown WebSocket protocol version.") + version = request.META.get('HTTP_SEC_WEBSOCKET_VERSION') + version_is_valid = False + if version: + try: + version = int(version) + except: + pass + else: + version_is_valid = version in WS_VERSION + + if not version_is_valid: + raise MalformedWebSocket + + # Compute the challenge response + key = request.META['HTTP_SEC_WEBSOCKET_KEY'] + handshake_response = b64encode(sha1(key + WS_KEY).digest()) + + # TODO : protocol negociation? Could be specified in the decorator... + protocols = request.META.get('HTTP_SEC_WEBSOCKET_PROTOCOL') + + # TODO : the 'Origin' field should be validated by application code + # (or configuration/per-view option) + + handshake_reply = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n") + handshake_reply += "Sec-WebSocket-Version: %s\r\n" % version + if protocols: + handshake_reply += "Sec-WebSocket-Protocol: %s\r\n" % protocols + handshake_reply += "Sec-WebSocket-Accept: %s\r\n" % handshake_response + handshake_reply += "\r\n" + + # Here we want to make sure that Django doesn't handle this request + # anymore + #request.META['wsgi.input']._sock = None socket = request.META['wsgi.input']._sock.dup() + # dup() is not portable because the folks writing Python forgot to + # backport it from 3.x to 2.7 + return WebSocket( socket, - protocol=request.META.get('HTTP_WEBSOCKET_PROTOCOL'), - version=protocol_version, - handshake_reply=handshake_reply, - ) + handshake_reply, + protocols) return None +DEFAULT_READING_SIZE = 2 + +# This class was adapted from ws4py.websocket:WebSocket +# Changes include: +# * process() inlined into _run, uses yield instead of received_message(), +# making it a generator +# * other callbacks removed +# * process() inlined into _run +# * added __iter__() +# * send_handshake() method class WebSocket(object): - """ - A websocket object that handles the details of - serialization/deserialization to the socket. - - The primary way to interact with a :class:`WebSocket` object is to - call :meth:`send` and :meth:`wait` in order to pass messages back - and forth with the browser. - """ - _socket_recv_bytes = 4096 - - - def __init__(self, socket, protocol, version=76, - handshake_reply=None, handshake_sent=None): - ''' - Arguments: - - - ``socket``: An open socket that should be used for WebSocket - communciation. - - ``protocol``: not used yet. - - ``version``: The WebSocket spec version to follow (default is 76) - - ``handshake_reply``: Handshake message that should be sent to the - client when ``send_handshake()`` is called. - - ``handshake_sent``: Whether the handshake is already sent or not. - Set to ``False`` to prevent ``send_handshake()`` to do anything. - ''' - self.socket = socket - self.protocol = protocol - self.version = version - self.closed = False + def __init__(self, sock, handshake_reply, protocols=None): + self.stream = Stream(always_mask=False) self.handshake_reply = handshake_reply - if handshake_sent is None: - self._handshake_sent = not bool(handshake_reply) - else: - self._handshake_sent = handshake_sent - self._buffer = "" - self._message_queue = collections.deque() + self.handshake_sent = False + self.protocols = protocols + self.sock = sock + self.client_terminated = False + self.server_terminated = False + self.reading_buffer_size = DEFAULT_READING_SIZE + self.sender = self.sock.sendall + + # This was initially a loop that used callbacks in ws4py + # Here it was turned into a generator, the callback replaced by yield + self.runner = self._run() def send_handshake(self): - self.socket.sendall(self.handshake_reply) - self._handshake_sent = True + self.sender(self.handshake_reply) + self.handshake_sent = True + + def wait(self): + """ + Reads a message from the websocket, blocking and responding to wire + messages until one becomes available. + """ + try: + return self.runner.next() + except StopIteration: + return None - @classmethod - def _pack_message(cls, message): - """Pack the message inside ``00`` and ``FF`` + def send(self, payload, binary=False): + """ + Sends the given ``payload`` out. - As per the dataframing section (5.3) for the websocket spec + If ``payload`` is some bytes or a bytearray, + then it is sent as a single message not fragmented. + + If ``payload`` is a generator, each chunk is sent as part of + fragmented message. + + If ``binary`` is set, handles the payload as a binary message. """ - if isinstance(message, unicode): - message = message.encode('utf-8') - elif not isinstance(message, str): - message = str(message) - packed = "\x00%s\xFF" % message - return packed - - def _parse_message_queue(self): - """ Parses for messages in the buffer *buf*. It is assumed that - the buffer contains the start character for a message, but that it - may contain only part of the rest of the message. - - Returns an array of messages, and the buffer remainder that - didn't contain any full messages.""" - msgs = [] - end_idx = 0 - buf = self._buffer - while buf: - frame_type = ord(buf[0]) - if frame_type == 0: - # Normal message. - end_idx = buf.find("\xFF") - if end_idx == -1: #pragma NO COVER + message_sender = self.stream.binary_message if binary else self.stream.text_message + + if isinstance(payload, basestring) or isinstance(payload, bytearray): + self.sender(message_sender(payload).single(mask=self.stream.always_mask)) + + elif isinstance(payload, Message): + self.sender(payload.single(mask=self.stream.always_mask)) + + elif type(payload) == types.GeneratorType: + bytes = payload.next() + first = True + for chunk in payload: + self.sender(message_sender(bytes).fragment(first=first, mask=self.stream.always_mask)) + bytes = chunk + first = False + + self.sender(message_sender(bytes).fragment(last=True, mask=self.stream.always_mask)) + + else: + raise ValueError("Unsupported type '%s' passed to send()" % type(payload)) + + def _cleanup(self): + """ + Frees up resources used by the endpoint. + """ + self.sender = None + self.sock = None + self.stream._cleanup() + self.stream = None + + def _run(self): + """ + Performs the operation of reading from the underlying + connection in order to feed the stream of bytes. + + We start with a small size of two bytes to be read + from the connection so that we can quickly parse an + incoming frame header. Then the stream indicates + whatever size must be read from the connection since + it knows the frame payload length. + + Note that we perform some automatic operations: + + * On a closing message, we respond with a closing + message and finally close the connection + * We respond to pings with pong messages. + * Whenever an error is raised by the stream parsing, + we initiate the closing of the connection with the + appropiate error code. + """ + self.sock.setblocking(True) + s = self.stream + try: + sock = self.sock + + while not self.terminated: + bytes = sock.recv(self.reading_buffer_size) + if not bytes and self.reading_buffer_size > 0: break - msgs.append(buf[1:end_idx].decode('utf-8', 'replace')) - buf = buf[end_idx+1:] - elif frame_type == 255: - # Closing handshake. - assert ord(buf[1]) == 0, "Unexpected closing handshake: %r" % buf - self.closed = True - break - else: - raise ValueError("Don't understand how to parse this type of message: %r" % buf) - self._buffer = buf - return msgs - - def send(self, message): - ''' - Send a message to the client. *message* should be convertable to a - string; unicode objects should be encodable as utf-8. - ''' - packed = self._pack_message(message) - self.socket.sendall(packed) - - def _socket_recv(self): - ''' - Gets new data from the socket and try to parse new messages. - ''' - delta = self.socket.recv(self._socket_recv_bytes) - if delta == '': - return False - self._buffer += delta - msgs = self._parse_message_queue() - self._message_queue.extend(msgs) - return True - - def _socket_can_recv(self, timeout=0.0): - ''' - Return ``True`` if new data can be read from the socket. - ''' - r, w, e = [self.socket], [], [] + + self.reading_buffer_size = s.parser.send(bytes) or DEFAULT_READING_SIZE + + if s.closing is not None: + if not self.server_terminated: + self.close(s.closing.code, s.closing.reason) + else: + self.client_terminated = True + break + + if s.errors: + for error in s.errors: + self.close(error.code, error.reason) + s.errors = [] + break + + if s.has_message: + yield s.message + s.message.data = None + s.message = None + else: + if s.pings: + for ping in s.pings: + self.sender(s.pong(ping.data)) + s.pings = [] + + if s.pongs: + s.pongs = [] + finally: + self.client_terminated = self.server_terminated = True + + s = sock = None + self.close_connection() + self._cleanup() + + def close(self, code=1000, reason=''): + """ + Call this method to initiate the websocket connection + closing by sending a close frame to the connected peer. + The ``code`` is the status code representing the + termination's reason. + + Once this method is called, the ``server_terminated`` + attribute is set. Calling this method several times is + safe as the closing frame will be sent only the first + time. + + .. seealso:: Defined Status Codes http://tools.ietf.org/html/rfc6455#section-7.4.1 + """ + if not self.server_terminated: + self.server_terminated = True + self.sender(self.stream.close(code=code, reason=reason).single(mask=self.stream.always_mask)) + + def close_connection(self): + """ + Shutdowns then closes the underlying connection. + """ try: - r, w, e = select.select(r, w, e, timeout) - except select.error, err: - if err.args[0] == EINTR: - return False - raise - return self.socket in r - - def _get_new_messages(self): - # read as long from socket as we need to get a new message. - while self._socket_can_recv(): - self._socket_recv() - if self._message_queue: - return - - def count_messages(self): - ''' - Returns the number of queued messages. - ''' - self._get_new_messages() - return len(self._message_queue) - - def has_messages(self): - ''' - Returns ``True`` if new messages from the socket are available, else - ``False``. - ''' - if self._message_queue: - return True - self._get_new_messages() - if self._message_queue: - return True - return False - - def read(self, fallback=None): - ''' - Return new message or ``fallback`` if no message is available. - ''' - if self.has_messages(): - return self._message_queue.popleft() - return fallback + self.sock.shutdown(socket.SHUT_RDWR) + self.sock.close() + except: + pass - def wait(self): - ''' - Waits for and deserializes messages. Returns a single message; the - oldest not yet processed. - ''' - while not self._message_queue: - # Websocket might be closed already. - if self.closed: - return None - # no parsed messages, must mean buf needs more data - new_data = self._socket_recv() - if not new_data: - return None - return self._message_queue.popleft() + @property + def terminated(self): + """ + Returns ``True`` if both the client and server have been + marked as terminated. + """ + return self.client_terminated is True and self.server_terminated is True def __iter__(self): - ''' - Use ``WebSocket`` as iterator. Iteration only stops when the websocket - gets closed by the client. - ''' - while True: - message = self.wait() - if message is None: - return - yield message - - def _send_closing_frame(self, ignore_send_errors=False): - ''' - Sends the closing frame to the client, if required. - ''' - if self.version == 76 and not self.closed: - try: - self.socket.sendall("\xff\x00") - except SocketError: - # Sometimes, like when the remote side cuts off the connection, - # we don't care about this. - if not ignore_send_errors: - raise - self.closed = True - - def close(self): - ''' - Forcibly close the websocket. - ''' - self._send_closing_frame() + return self.runner diff --git a/requirements.txt b/requirements.txt index 088efaf..af935ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ Django==1.4.0 mock==0.8.0 +ws4py