Skip to content
This repository was archived by the owner on Jan 13, 2021. It is now read-only.

Commit d395d57

Browse files
committed
Merge pull request #186 from irvind/override-default-headers
Allow to override default request headers
2 parents 31cbc84 + f8a0b04 commit d395d57

File tree

6 files changed

+151
-27
lines changed

6 files changed

+151
-27
lines changed

hyper/common/headers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,19 @@ def iter_raw(self):
181181
for item in self._items:
182182
yield item
183183

184+
def replace(self, key, value):
185+
"""
186+
Replace existing header with new value. If header doesn't exist this
187+
method work like ``__setitem__``. Replacing leads to deletion of all
188+
exsiting headers with the same name.
189+
"""
190+
try:
191+
del self[key]
192+
except KeyError:
193+
pass
194+
195+
self[key] = value
196+
184197
def merge(self, other):
185198
"""
186199
Merge another header set or any other dict-like into this one.

hyper/common/util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
"""
88
from hyper.compat import unicode, bytes, imap
99
from ..packages.rfc3986.uri import URIReference
10+
from ..compat import is_py3
1011
import re
1112

13+
1214
def to_bytestring(element):
1315
"""
1416
Converts a single string to a bytestring, encoding via UTF-8 if needed.
@@ -28,6 +30,7 @@ def to_bytestring_tuple(*x):
2830
"""
2931
return tuple(imap(to_bytestring, x))
3032

33+
3134
def to_host_port_tuple(host_port_str, default_port=80):
3235
"""
3336
Converts the given string containing a host and possibly a port
@@ -48,3 +51,10 @@ def to_host_port_tuple(host_port_str, default_port=80):
4851
port = int(uri.port)
4952

5053
return (host, port)
54+
55+
56+
def to_native_string(string, encoding='utf-8'):
57+
if isinstance(string, str):
58+
return string
59+
60+
return string.decode(encoding) if is_py3 else string.encode(encoding)

hyper/http20/connection.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ..common.exceptions import ConnectionResetError
1010
from ..common.bufsocket import BufferedSocket
1111
from ..common.headers import HTTPHeaderMap
12-
from ..common.util import to_host_port_tuple
12+
from ..common.util import to_host_port_tuple, to_native_string
1313
from ..packages.hyperframe.frame import (
1414
FRAMES, DataFrame, HeadersFrame, PushPromiseFrame, RstStreamFrame,
1515
SettingsFrame, Frame, WindowUpdateFrame, GoAwayFrame, PingFrame,
@@ -170,8 +170,10 @@ def request(self, method, url, body=None, headers={}):
170170
"""
171171
stream_id = self.putrequest(method, url)
172172

173+
default_headers = (':method', ':scheme', ':authority', ':path')
173174
for name, value in headers.items():
174-
self.putheader(name, value, stream_id)
175+
is_default = to_native_string(name) in default_headers
176+
self.putheader(name, value, stream_id, replace=is_default)
175177

176178
# Convert the body to bytes if needed.
177179
if isinstance(body, str):
@@ -319,7 +321,7 @@ def putrequest(self, method, selector, **kwargs):
319321

320322
return s.stream_id
321323

322-
def putheader(self, header, argument, stream_id=None):
324+
def putheader(self, header, argument, stream_id=None, replace=False):
323325
"""
324326
Sends an HTTP header to the server, with name ``header`` and value
325327
``argument``.
@@ -341,7 +343,7 @@ def putheader(self, header, argument, stream_id=None):
341343
:returns: Nothing.
342344
"""
343345
stream = self._get_stream(stream_id)
344-
stream.add_header(header, argument)
346+
stream.add_header(header, argument, replace)
345347

346348
return
347349

hyper/http20/stream.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self,
6161
local_closed=False):
6262
self.stream_id = stream_id
6363
self.state = STATE_HALF_CLOSED_LOCAL if local_closed else STATE_IDLE
64-
self.headers = []
64+
self.headers = HTTPHeaderMap()
6565

6666
# Set to a key-value set of the response headers once their
6767
# HEADERS..CONTINUATION frame sequence finishes.
@@ -109,11 +109,15 @@ def __init__(self,
109109
self._encoder = header_encoder
110110
self._decoder = header_decoder
111111

112-
def add_header(self, name, value):
112+
def add_header(self, name, value, replace=False):
113113
"""
114114
Adds a single HTTP header to the headers to be sent on the request.
115115
"""
116-
self.headers.append((name.lower(), value))
116+
if not replace:
117+
self.headers[name] = value
118+
else:
119+
self.headers.replace(name, value)
120+
117121

118122
def send_data(self, data, final):
119123
"""
@@ -270,6 +274,7 @@ def open(self, end):
270274
"""
271275
# Strip any headers invalid in H2.
272276
headers = h2_safe_headers(self.headers)
277+
273278
# Encode the headers.
274279
encoded_headers = self._encoder.encode(headers)
275280

test/test_headers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,22 @@ def test_merge_header_map_dict(self):
258258
(b'hi', b'there'),
259259
(b'cat', b'dog'),
260260
]
261+
262+
def test_replacing(self):
263+
h = HTTPHeaderMap([
264+
(b'name', b'value'),
265+
(b'name2', b'value2'),
266+
(b'name2', b'value2'),
267+
(b'name3', b'value3'),
268+
])
269+
270+
h.replace('name2', '42')
271+
h.replace('name4', 'other_value')
272+
273+
assert list(h.items()) == [
274+
(b'name', b'value'),
275+
(b'name3', b'value3'),
276+
(b'name2', b'42'),
277+
(b'name4', b'other_value'),
278+
]
279+

test/test_hyper.py

Lines changed: 95 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
combine_repeated_headers, split_repeated_headers, h2_safe_headers
1919
)
2020
from hyper.common.headers import HTTPHeaderMap
21-
from hyper.compat import zlib_compressobj
21+
from hyper.compat import zlib_compressobj, is_py2
2222
from hyper.contrib import HTTP20Adapter
2323
import hyper.http20.errors as errors
2424
import errno
@@ -29,6 +29,7 @@
2929
from io import BytesIO
3030
import hyper
3131

32+
3233
def decode_frame(frame_data):
3334
f, length = Frame.parse_frame_header(frame_data[:9])
3435
f.parse_body(memoryview(frame_data[9:9 + length]))
@@ -87,11 +88,11 @@ def test_putrequest_autosets_headers(self):
8788
c.putrequest('GET', '/')
8889
s = c.recent_stream
8990

90-
assert s.headers == [
91-
(':method', 'GET'),
92-
(':scheme', 'https'),
93-
(':authority', 'www.google.com'),
94-
(':path', '/'),
91+
assert list(s.headers.items()) == [
92+
(b':method', b'GET'),
93+
(b':scheme', b'https'),
94+
(b':authority', b'www.google.com'),
95+
(b':path', b'/'),
9596
]
9697

9798
def test_putheader_puts_headers(self):
@@ -101,12 +102,29 @@ def test_putheader_puts_headers(self):
101102
c.putheader('name', 'value')
102103
s = c.recent_stream
103104

104-
assert s.headers == [
105-
(':method', 'GET'),
106-
(':scheme', 'https'),
107-
(':authority', 'www.google.com'),
108-
(':path', '/'),
109-
('name', 'value'),
105+
assert list(s.headers.items()) == [
106+
(b':method', b'GET'),
107+
(b':scheme', b'https'),
108+
(b':authority', b'www.google.com'),
109+
(b':path', b'/'),
110+
(b'name', b'value'),
111+
]
112+
113+
def test_putheader_replaces_headers(self):
114+
c = HTTP20Connection("www.google.com")
115+
116+
c.putrequest('GET', '/')
117+
c.putheader(':authority', 'www.example.org', replace=True)
118+
c.putheader('name', 'value')
119+
c.putheader('name', 'value2', replace=True)
120+
s = c.recent_stream
121+
122+
assert list(s.headers.items()) == [
123+
(b':method', b'GET'),
124+
(b':scheme', b'https'),
125+
(b':path', b'/'),
126+
(b':authority', b'www.example.org'),
127+
(b'name', b'value2'),
110128
]
111129

112130
def test_endheaders_sends_data(self):
@@ -203,6 +221,33 @@ def test_putrequest_sends_data(self):
203221
assert len(sock.queue) == 2
204222
assert c._out_flow_control_window == 65535 - len(b'hello')
205223

224+
def test_different_request_headers(self):
225+
sock = DummySocket()
226+
227+
c = HTTP20Connection('www.google.com')
228+
c._sock = sock
229+
c.request('GET', '/', body='hello', headers={b'name': b'value'})
230+
s = c.recent_stream
231+
232+
assert list(s.headers.items()) == [
233+
(b':method', b'GET'),
234+
(b':scheme', b'https'),
235+
(b':authority', b'www.google.com'),
236+
(b':path', b'/'),
237+
(b'name', b'value'),
238+
]
239+
240+
c.request('GET', '/', body='hello', headers={u'name2': u'value2'})
241+
s = c.recent_stream
242+
243+
assert list(s.headers.items()) == [
244+
(b':method', b'GET'),
245+
(b':scheme', b'https'),
246+
(b':authority', b'www.google.com'),
247+
(b':path', b'/'),
248+
(b'name2', b'value2'),
249+
]
250+
206251
def test_closed_connections_are_reset(self):
207252
c = HTTP20Connection('www.google.com')
208253
c._sock = DummySocket()
@@ -502,11 +547,11 @@ def test_that_using_proxy_keeps_http_headers_intact(self):
502547
c.request('GET', '/')
503548
s = c.recent_stream
504549

505-
assert s.headers == [
506-
(':method', 'GET'),
507-
(':scheme', 'http'),
508-
(':authority', 'www.google.com'),
509-
(':path', '/'),
550+
assert list(s.headers.items()) == [
551+
(b':method', b'GET'),
552+
(b':scheme', b'http'),
553+
(b':authority', b'www.google.com'),
554+
(b':path', b'/'),
510555
]
511556

512557
def test_recv_cb_n_times(self):
@@ -695,13 +740,30 @@ def test_streams_have_ids(self):
695740

696741
def test_streams_initially_have_no_headers(self):
697742
s = Stream(1, None, None, None, None, None, None)
698-
assert s.headers == []
743+
assert list(s.headers.items()) == []
699744

700745
def test_streams_can_have_headers(self):
701746
s = Stream(1, None, None, None, None, None, None)
702747
s.add_header("name", "value")
703-
assert s.headers == [("name", "value")]
748+
assert list(s.headers.items()) == [(b"name", b"value")]
749+
750+
def test_streams_can_replace_headers(self):
751+
s = Stream(1, None, None, None, None, None, None)
752+
s.add_header("name", "value")
753+
s.add_header("name", "other_value", replace=True)
704754

755+
assert list(s.headers.items()) == [(b"name", b"other_value")]
756+
757+
def test_streams_can_replace_none_headers(self):
758+
s = Stream(1, None, None, None, None, None, None)
759+
s.add_header("name", "value")
760+
s.add_header("other_name", "other_value", replace=True)
761+
762+
assert list(s.headers.items()) == [
763+
(b"name", b"value"),
764+
(b"other_name", b"other_value")
765+
]
766+
705767
def test_stream_opening_sends_headers(self):
706768
def data_callback(frame):
707769
assert isinstance(frame, HeadersFrame)
@@ -1465,11 +1527,23 @@ def test_connection_error_when_send_out_of_range_frame(self):
14651527
with pytest.raises(ValueError):
14661528
c._send_cb(d)
14671529

1530+
14681531
# Some utility classes for the tests.
14691532
class NullEncoder(object):
14701533
@staticmethod
14711534
def encode(headers):
1472-
return '\n'.join("%s%s" % (name, val) for name, val in headers)
1535+
1536+
def to_str(v):
1537+
if is_py2:
1538+
return str(v)
1539+
else:
1540+
if not isinstance(v, str):
1541+
v = str(v, 'utf-8')
1542+
return v
1543+
1544+
return '\n'.join("%s%s" % (to_str(name), to_str(val))
1545+
for name, val in headers)
1546+
14731547

14741548
class FixedDecoder(object):
14751549
def __init__(self, result):
@@ -1478,6 +1552,7 @@ def __init__(self, result):
14781552
def decode(self, headers):
14791553
return self.result
14801554

1555+
14811556
class DummySocket(object):
14821557
def __init__(self):
14831558
self.queue = []

0 commit comments

Comments
 (0)