Skip to content

Commit 5c01f86

Browse files
committed
more work on multidict
1 parent 2a37e56 commit 5c01f86

File tree

1 file changed

+55
-69
lines changed

1 file changed

+55
-69
lines changed

aiohttp/multidict.py

Lines changed: 55 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,36 @@ def __init__(self, *args, **kwargs):
2323
args = list(args[0].items())
2424
else:
2525
args = list(args[0])
26+
for arg in args:
27+
if not len(arg) == 2:
28+
raise TypeError(
29+
"MultiDict takes either dict or list of \
30+
(key, value) tuples"
31+
)
2632

2733
self._fill(chain(args, kwargs.items()))
2834

2935
def _fill(self, ipairs):
3036
self._items.extend(ipairs)
3137

3238
def getall(self, key, default=_marker):
33-
"""Returns all values stored at key as a tuple.
34-
35-
Raises KeyError if key doesn't exist."""
36-
if key in self._items:
37-
return tuple(self._items[key])
38-
else:
39-
if default is not _marker:
40-
return default
41-
else:
42-
raise KeyError(key)
43-
44-
def getone(self, key):
39+
"""
40+
Return a list of all values matching the key (may be an empty list)
41+
"""
42+
res = tuple([v for k, v in self._items if k == key])
43+
if res:
44+
return res
45+
if not res and default != _marker:
46+
return default
47+
raise KeyError('Key not found: %r' % key)
48+
49+
def getone(self, key, default=_marker):
4550
"""
4651
Get one value matching the key, raising a KeyError if multiple
4752
values were found.
4853
"""
49-
v = self.getall(key)
50-
if not v:
51-
raise KeyError('Key not found: %r' % key)
52-
if len(v) > 1:
54+
v = self.getall(key, default=_marker)
55+
if len(v) > 1 and v != default:
5356
raise KeyError('Multiple values match %r: %r' % (key, v))
5457
return v[0]
5558

@@ -63,7 +66,7 @@ def copy(self):
6366
# Mapping interface #
6467

6568
def __getitem__(self, key):
66-
for k, v in reversed(self._items):
69+
for k, v in self._items:
6770
if k == key:
6871
return v
6972
raise KeyError(key)
@@ -74,14 +77,14 @@ def __iter__(self):
7477
def __len__(self):
7578
return len(self._items)
7679

77-
def keys(self, *, getall=False):
78-
return _KeysView(self._items, getall=getall)
80+
def keys(self):
81+
return _KeysView(self._items)
7982

80-
def items(self, *, getall=False):
81-
return _ItemsView(self._items, getall=getall)
83+
def items(self):
84+
return _ItemsView(self._items)
8285

83-
def values(self, *, getall=False):
84-
return _ValuesView(self._items, getall=getall)
86+
def values(self):
87+
return _ValuesView(self._items)
8588

8689
def __eq__(self, other):
8790
if not isinstance(other, abc.Mapping):
@@ -99,7 +102,7 @@ def __contains__(self, key):
99102
def __repr__(self):
100103
return '<{}>\n{}'.format(
101104
self.__class__.__name__, pprint.pformat(
102-
list(self.items(getall=True))))
105+
list(self.items())))
103106

104107

105108
class CaseInsensitiveMultiDict(MultiDict):
@@ -114,20 +117,14 @@ def _from_uppercase_multidict(cls, dct):
114117

115118
def _fill(self, ipairs):
116119
for key, value in ipairs:
117-
key = key.upper()
118-
if key in self._items:
119-
self._items[key].append(value)
120-
else:
121-
self._items[key] = [value]
120+
uppkey = key.upper()
121+
self._items.append((uppkey, value))
122122

123123
def getall(self, key, default=_marker):
124124
return super().getall(key.upper(), default)
125125

126-
def get(self, key, default=None):
127-
return self.get(key.upper(), default)
128-
129-
def getone(self, key):
130-
return self._items[key.upper()][0]
126+
def getone(self, key, default=_marker):
127+
return super().getone(key.upper(), default)
131128

132129
def __getitem__(self, key):
133130
return super().__getitem__(key.upper())
@@ -136,18 +133,7 @@ def __contains__(self, key):
136133
return super().__contains__(key.upper())
137134

138135

139-
class BaseMutableMultiDict(abc.MutableMapping):
140-
141-
def getall(self, key, default=_marker):
142-
"""Returns all values stored at key as list.
143-
144-
Raises KeyError if key doesn't exist.
145-
"""
146-
result = super().getall(key, default)
147-
if result is not default:
148-
return list(result)
149-
else:
150-
return result
136+
class MutableMultiDictMixin(abc.MutableMapping):
151137

152138
def add(self, key, value):
153139
"""
@@ -211,17 +197,14 @@ def update(self, *args, **kw):
211197
raise NotImplementedError("Use extend method instead")
212198

213199

214-
class MutableMultiDict(BaseMutableMultiDict, MultiDict):
200+
class MutableMultiDict(MutableMultiDictMixin, MultiDict):
215201
"""An ordered dictionary that can have multiple values for each key."""
216202

217203

218204
class CaseInsensitiveMutableMultiDict(
219-
BaseMutableMultiDict, CaseInsensitiveMultiDict):
205+
MutableMultiDictMixin, CaseInsensitiveMultiDict):
220206
"""An ordered dictionary that can have multiple values for each key."""
221207

222-
def getall(self, key, default=_marker):
223-
return super().getall(key.upper(), default)
224-
225208
def add(self, key, value):
226209
super().add(key.upper(), value)
227210

@@ -232,31 +215,34 @@ def __delitem__(self, key):
232215
super().__delitem__(key.upper())
233216

234217

235-
class _KeysView(abc.ItemsView):
218+
class _ItemsView(abc.ItemsView):
236219

237-
def __init__(self, items, *, getall=False):
238-
super().__init__(items)
239-
self._getall = getall
240-
# TBD
220+
def __contains__(self, item):
221+
assert isinstance(item, tuple) or isinstance(item, list)
222+
assert len(item) == 2
223+
return item in self._mapping
241224

225+
def __iter__(self):
226+
yield from self._mapping
242227

243-
class _ItemsView(abc.ItemsView):
244228

245-
def __init__(self, items, *, getall=False):
246-
super().__init__(items)
247-
self._getall = getall
229+
class _ValuesView(abc.ValuesView):
248230

249-
def __contains__(self, item):
250-
# TBD
251-
pass
231+
def __contains__(self, value):
232+
values = [item[1] for item in self._mapping]
233+
return value in values
252234

253235
def __iter__(self):
254-
pass
236+
values = [item[1] for item in self._mapping]
237+
yield from values
255238

256239

257-
class _ValuesView(abc.KeysView):
240+
class _KeysView(abc.KeysView):
258241

259-
def __init__(self, mapping, *, getall=False):
260-
super().__init__(mapping)
261-
self._getall = getall
262-
# TBD
242+
def __contains__(self, key):
243+
keys = set([item[0] for item in self._mapping])
244+
return key in keys
245+
246+
def __iter__(self):
247+
keys = set([item[0] for item in self._mapping])
248+
yield from keys

0 commit comments

Comments
 (0)