Skip to content

Commit 5b551a8

Browse files
committed
Fix stable ordering for multidict iterations, add more tests
1 parent 51e2476 commit 5b551a8

File tree

2 files changed

+50
-48
lines changed

2 files changed

+50
-48
lines changed

aiohttp/multidict.py

Lines changed: 24 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
import pprint
2-
from itertools import chain
2+
from itertools import chain, filterfalse
33
from collections import abc
44

55
_marker = object()
66

77

8+
def _unique_everseen(iterable):
9+
"""List unique elements, preserving order.
10+
Remember all elements ever seen.
11+
Recipe from
12+
https://docs.python.org/3/library/itertools.html#itertools-recipes"""
13+
# unique_everseen('AAAABBBCCDAABBB') --> A B C D
14+
# unique_everseen('ABBCcAD', str.lower) --> A B C D
15+
seen = set()
16+
seen_add = seen.add
17+
for element in filterfalse(seen.__contains__, iterable):
18+
seen_add(element)
19+
yield element
20+
21+
822
class MultiDict(abc.Mapping):
923
"""Read-only ordered dictionary that can have multiple values for each key.
1024
@@ -25,10 +39,8 @@ def __init__(self, *args, **kwargs):
2539
args = list(args[0])
2640
for arg in args:
2741
if not len(arg) == 2:
28-
raise TypeError(
29-
"MultiDict takes either dict or list of \
30-
(key, value) tuples"
31-
)
42+
raise TypeError("MultiDict takes either dict "
43+
"or list of (key, value) tuples")
3244

3345
self._fill(chain(args, kwargs.items()))
3446

@@ -217,13 +229,13 @@ def __delitem__(self, key):
217229
super().__delitem__(key.upper())
218230

219231

220-
class _ItemsView(abc.ItemsView):
232+
class _ViewBase:
221233

222234
def __init__(self, items, *, getall=False):
223235
self._getall = getall
224236
self._keys = [item[0] for item in items]
225237
if not getall:
226-
self._keys = set(self._keys)
238+
self._keys = list(_unique_everseen(self._keys))
227239

228240
items_to_use = []
229241
if getall:
@@ -238,6 +250,9 @@ def __init__(self, items, *, getall=False):
238250

239251
super().__init__(items_to_use)
240252

253+
254+
class _ItemsView(_ViewBase, abc.ItemsView):
255+
241256
def __contains__(self, item):
242257
assert isinstance(item, tuple) or isinstance(item, list)
243258
assert len(item) == 2
@@ -247,27 +262,7 @@ def __iter__(self):
247262
yield from self._mapping
248263

249264

250-
class _ValuesView(abc.ValuesView):
251-
252-
def __init__(self, items, *, getall=False):
253-
self._getall = getall
254-
self._keys = [item[0] for item in items]
255-
if not getall:
256-
self._keys = set(self._keys)
257-
258-
items_to_use = []
259-
if getall:
260-
items_to_use = items
261-
else:
262-
for key in self._keys:
263-
for k, v in items:
264-
if k == key:
265-
items_to_use.append((k, v))
266-
break
267-
268-
assert len(items_to_use) == len(self._keys)
269-
270-
super().__init__(items_to_use)
265+
class _ValuesView(_ViewBase, abc.ValuesView):
271266

272267
def __contains__(self, value):
273268
for item in self._mapping:
@@ -280,26 +275,7 @@ def __iter__(self):
280275
yield item[1]
281276

282277

283-
class _KeysView(abc.KeysView):
284-
285-
def __init__(self, items, *, getall=False):
286-
self._getall = getall
287-
self._keys = [item[0] for item in items]
288-
if not getall:
289-
self._keys = set(self._keys)
290-
291-
items_to_use = []
292-
if getall:
293-
items_to_use = items
294-
else:
295-
for key in self._keys:
296-
for k, v in items:
297-
if k == key:
298-
items_to_use.append((k, v))
299-
break
300-
assert len(items_to_use) == len(self._keys)
301-
302-
super().__init__(items_to_use)
278+
class _KeysView(_ViewBase, abc.KeysView):
303279

304280
def __contains__(self, key):
305281
return key in self._keys

tests/test_multidict.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,29 @@ def test_getone(self):
6767
with self.assertRaises(KeyError):
6868
d.getone('key2')
6969

70+
self.assertEqual('default', d.getone('key2', 'default'))
71+
7072
def test_copy(self):
7173
d1 = self.make_dict(key='value', a='b')
7274

7375
d2 = d1.copy()
7476
self.assertEqual(d1, d2)
7577
self.assertIsNot(d1, d2)
7678

79+
def test_keys__contains(self):
80+
d = self.make_dict([('key', 'one'), ('key2', 'two'), ('key', 3)])
81+
self.assertEqual(list(d.keys()), ['key', 'key2'])
82+
self.assertEqual(list(d.keys(getall=True)), ['key', 'key2', 'key'])
83+
84+
self.assertIn('key', d.keys())
85+
self.assertIn('key2', d.keys())
86+
87+
self.assertIn('key', d.keys(getall=True))
88+
self.assertIn('key2', d.keys(getall=True))
89+
90+
self.assertNotIn('foo', d.keys())
91+
self.assertNotIn('foo', d.keys(getall=True))
92+
7793
def test_values__contains(self):
7894
d = self.make_dict([('key', 'one'), ('key', 'two'), ('key', 3)])
7995
self.assertEqual(list(d.values()), ['one'])
@@ -107,6 +123,10 @@ def test_items__contains(self):
107123
self.assertNotIn(('foo', 'bar'), d.items())
108124
self.assertNotIn(('foo', 'bar'), d.items(getall=True))
109125

126+
def test_cannot_create_from_unaccepted(self):
127+
with self.assertRaises(TypeError):
128+
self.make_dict([(1, 2, 3)])
129+
110130

111131
class MultiDictTests(_BaseTest, unittest.TestCase):
112132

@@ -134,6 +154,12 @@ def test_getall(self):
134154
default = object()
135155
self.assertIs(d.getall('some_key', default), default)
136156

157+
def test_preserve_stable_ordering(self):
158+
d = self.make_dict([('a', 1), ('b', '2'), ('a', 3)])
159+
s = '&'.join('{}={}'.format(k, v) for k, v in d.items(getall=True))
160+
161+
self.assertEqual('a=1&b=2&a=3', s)
162+
137163

138164
class CaseInsensitiveMultiDictTests(unittest.TestCase):
139165

0 commit comments

Comments
 (0)