Skip to content

Commit 920f1cb

Browse files
committed
Make str or upstr mandatory for multidict keys
1 parent dad10ed commit 920f1cb

File tree

1 file changed

+30
-28
lines changed

1 file changed

+30
-28
lines changed

aiohttp/_multidict.pyx

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,18 @@ cdef class _Base:
3434
def __cinit__(self):
3535
self._upstr = upstr
3636

37-
cdef _upper(self, key):
38-
return key
37+
cdef str _upper(self, s):
38+
if type(s) is self._upstr:
39+
return <str>s
40+
return s
3941

4042
def getall(self, key, default=_marker):
4143
"""
4244
Return a list of all values matching the key (may be an empty list)
4345
"""
44-
return self._getall(key, default)
46+
return self._getall(self._upper(key), default)
4547

46-
cdef _getall(self, key, default):
48+
cdef _getall(self, str key, default):
4749
cdef list res
4850
key = self._upper(key)
4951
res = [v for k, v in self._items if k == key]
@@ -57,13 +59,13 @@ cdef class _Base:
5759
"""
5860
Get first value matching the key
5961
"""
60-
return self._getone(key, default)
62+
return self._getone(self._upper(key), default)
6163

62-
cdef _getone(self, key, default):
64+
cdef _getone(self, str key, default):
6365
cdef tuple item
6466
key = self._upper(key)
6567
for item in self._items:
66-
if item[0] == key:
68+
if <str>item[0] == key:
6769
return item[1]
6870
if default is not _marker:
6971
return default
@@ -72,19 +74,19 @@ cdef class _Base:
7274
# Mapping interface #
7375

7476
def __getitem__(self, key):
75-
return self._getone(key, _marker)
77+
return self._getone(self._upper(key), _marker)
7678

7779
def get(self, key, default=None):
78-
return self._getone(key, default)
80+
return self._getone(self._upper(key), default)
7981

8082
def __contains__(self, key):
81-
return self._contains(key)
83+
return self._contains(self._upper(key))
8284

83-
cdef _contains(self, key):
85+
cdef _contains(self, str key):
8486
cdef tuple item
8587
key = self._upper(key)
8688
for item in self._items:
87-
if item[0] == key:
89+
if <str>item[0] == key:
8890
return True
8991
return False
9092

@@ -177,9 +179,9 @@ cdef class CIMultiDictProxy(MultiDictProxy):
177179
mdict = arg
178180
self._items = mdict._items
179181

180-
cdef _upper(self, s):
182+
cdef str _upper(self, s):
181183
if type(s) is self._upstr:
182-
return s
184+
return <str>s
183185
return s.upper()
184186

185187
def copy(self):
@@ -199,6 +201,7 @@ cdef class MultiDict(_Base):
199201

200202
cdef _extend(self, tuple args, dict kwargs, name, int do_add):
201203
cdef tuple item
204+
cdef str key
202205

203206
if len(args) > 1:
204207
raise TypeError("{} takes at most 1 positional argument"
@@ -234,10 +237,10 @@ cdef class MultiDict(_Base):
234237
else:
235238
self._replace(key, value)
236239

237-
cdef _add(self, key, value):
240+
cdef _add(self, str key, value):
238241
self._items.append((key, value))
239242

240-
cdef _replace(self, key, value):
243+
cdef _replace(self, str key, value):
241244
self._remove(key, 0)
242245
self._items.append((key, value))
243246

@@ -266,15 +269,12 @@ cdef class MultiDict(_Base):
266269
# MutableMapping interface #
267270

268271
def __setitem__(self, key, value):
269-
key = self._upper(key)
270-
self._remove(key, False)
271-
self._add(key, value)
272+
self._replace(self._upper(key), value)
272273

273274
def __delitem__(self, key):
274-
key = self._upper(key)
275-
self._remove(key, True)
275+
self._remove(self._upper(key), True)
276276

277-
cdef _remove(self, key, int raise_key_error):
277+
cdef _remove(self, str key, int raise_key_error):
278278
cdef int found
279279
found = False
280280
for i in range(len(self._items) - 1, -1, -1):
@@ -285,17 +285,19 @@ cdef class MultiDict(_Base):
285285
raise KeyError(key)
286286

287287
def setdefault(self, key, default=None):
288-
key = self._upper(key)
288+
cdef str skey
289+
skey = self._upper(key)
289290
for k, v in self._items:
290-
if k == key:
291+
if k == skey:
291292
return v
292-
self._add(key, default)
293+
self._add(skey, default)
293294
return default
294295

295296
def pop(self, key, default=_marker):
296297
cdef int found
298+
cdef str skey
297299
cdef object value
298-
key = self._upper(key)
300+
skey = self._upper(key)
299301
value = None
300302
found = False
301303
for i in range(len(self._items) - 1, -1, -1):
@@ -357,9 +359,9 @@ abc.MutableMapping.register(MultiDict)
357359
cdef class CIMultiDict(MultiDict):
358360
"""An ordered dictionary that can have multiple values for each key."""
359361

360-
cdef _upper(self, s):
362+
cdef str _upper(self, s):
361363
if type(s) is self._upstr:
362-
return s
364+
return <str>s
363365
return s.upper()
364366

365367

0 commit comments

Comments
 (0)