diff --git a/tests/http_server.py b/tests/http_server.py index bee754e..a80f172 100644 --- a/tests/http_server.py +++ b/tests/http_server.py @@ -131,8 +131,8 @@ async def on_request(**server): async def on_response(content_type=b'text/plain', **server): response = server['response'] - assert response.headers[b'x-foo'] == [b'X-Foo: bar'] - assert b'Set-Cookie: sess=www; ' in response.headers[b'set-cookie'][0] + assert response.headers[b'x-foo'] == [b'bar'] + assert b'sess=www; ' in response.headers[b'set-cookie'][0] response.set_header(b'X-Foo', b'baz') diff --git a/tremolo/lib/http_response.py b/tremolo/lib/http_response.py index e4bfd8d..615fa67 100644 --- a/tremolo/lib/http_response.py +++ b/tremolo/lib/http_response.py @@ -24,6 +24,14 @@ CONNECTIONS = (b'close', b'keep-alive', b'upgrade') +class HeaderKey(bytes): + def __new__(cls, name): + obj = super().__new__(cls, name.lower()) + obj.name = name # original casing + + return obj + + class HTTPResponse(Response): __slots__ = ('line', 'content_type', 'http_chunked', '_headers') @@ -59,12 +67,12 @@ def append_header(self, name, value): if b'\r' in name or b'\n' in name or b'\r' in value or b'\n' in value: raise InternalServerError - key = name.lower() + key = HeaderKey(name) if key in self.headers: - self.headers[key].append(name + b': ' + value) + self.headers[key].append(value) else: - self.headers[key] = [name + b': ' + value] + self.headers[key] = [value] def set_header(self, name, value=''): if isinstance(name, str): @@ -76,7 +84,7 @@ def set_header(self, name, value=''): if b'\r' in name or b'\n' in name or b'\r' in value or b'\n' in value: raise InternalServerError - self.headers[name.lower()] = [name + b': ' + value] + self.headers[HeaderKey(name)] = [value] def set_base_headers(self): self.set_header( @@ -179,8 +187,9 @@ async def end(self, data=b'', *, keepalive=True, **kwargs): self.content_type, content_length, CONNECTIONS[keepalive and self.request.http_keepalive], - b'\r\n'.join(b'\r\n'.join(v) for k, v in self.headers.items() - if k not in excludes), + b'\r\n'.join(k.name + b': ' + value for k in self.headers + if k not in excludes + for value in self.headers[k]), data), **kwargs ) self.headers_sent(True) @@ -245,7 +254,8 @@ async def write(self, data=None, *, chunked=None, buffer_size=16384, await self.send( b' '.join(self.line) + b'\r\n' + - b'\r\n'.join(b'\r\n'.join(v) for v in self.headers.values()) + + b'\r\n'.join(k.name + b': ' + value for k in self.headers + for value in self.headers[k]) + b'\r\n\r\n' ) self.headers_sent(True)