Skip to content

Commit 50a58d4

Browse files
authored
Restore multidict internal state on exception during updating (#1215)
1 parent 0b046c9 commit 50a58d4

File tree

5 files changed

+61
-42
lines changed

5 files changed

+61
-42
lines changed

CHANGES/1215.bugfix.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
If :meth:`multidict.MultiDict.extend`, :meth:`multidict.MultiDict.merge`, or :meth:`multidict.MultiDict.update` raises an exception, the multidict internal state is now correctly restored.
2+
Patch by :user:`asvetlov`.

multidict/_multidict.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,18 @@ _multidict_extend(MultiDictObject *self, PyObject *arg, PyObject *kwds,
118118
}
119119

120120
if (op != Extend) { // Update or Merge
121-
if (md_post_update(self) < 0) {
122-
goto fail;
123-
}
121+
md_post_update(self);
124122
}
125123

126124
ASSERT_CONSISTENT(self, false);
127125
Py_CLEAR(seq);
128126
return 0;
129127
fail:
128+
if (op != Extend) { // Update or Merge
129+
// Cleanup soft-deleted items
130+
md_post_update(self);
131+
}
132+
ASSERT_CONSISTENT(self, false);
130133
Py_CLEAR(seq);
131134
return -1;
132135
}

multidict/_multidict_py.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -445,11 +445,11 @@ def _identity(self, key: str) -> str:
445445
if isinstance(key, istr):
446446
ret = key.__istr_identity__
447447
if ret is None:
448-
ret = key.title()
448+
ret = key.lower()
449449
key.__istr_identity__ = ret
450450
return ret
451451
if isinstance(key, str):
452-
return key.title()
452+
return key.lower()
453453
else:
454454
raise TypeError("MultiDict keys should be either str or subclasses of str")
455455

@@ -632,13 +632,13 @@ def __init__(self, arg: MDArg[_V] = None, /, **kwargs: _V):
632632
self._from_md(md)
633633
return
634634

635-
items = self._parse_args(arg, kwargs)
636-
log2_size = estimate_log2_keysize(len(items))
635+
it = self._parse_args(arg, kwargs)
636+
log2_size = estimate_log2_keysize(cast(int, next(it)))
637637
if log2_size > 17: # pragma: no cover
638638
# Don't overallocate really huge keys space in init
639639
log2_size = 17
640640
self._keys: _HtKeys[_V] = _HtKeys.new(log2_size, [])
641-
self._extend_items(items)
641+
self._extend_items(cast(Iterator[_Entry[_V]], it))
642642

643643
def _from_md(self, md: "MultiDict[_V]") -> None:
644644
# Copy everything as-is without compacting the new multidict,
@@ -790,58 +790,57 @@ def extend(self, arg: MDArg[_V] = None, /, **kwargs: _V) -> None:
790790
791791
This method must be used instead of update.
792792
"""
793-
items = self._parse_args(arg, kwargs)
794-
newsize = self._used + len(items)
793+
it = self._parse_args(arg, kwargs)
794+
newsize = self._used + cast(int, next(it))
795795
self._resize(estimate_log2_keysize(newsize), False)
796-
self._extend_items(items)
796+
self._extend_items(cast(Iterator[_Entry[_V]], it))
797797

798798
def _parse_args(
799799
self,
800800
arg: MDArg[_V],
801801
kwargs: Mapping[str, _V],
802-
) -> list[_Entry[_V]]:
802+
) -> Iterator[Union[int, _Entry[_V]]]:
803803
identity_func = self._identity
804804
if arg:
805805
if isinstance(arg, MultiDictProxy):
806806
arg = arg._md
807807
if isinstance(arg, MultiDict):
808+
yield len(arg) + len(kwargs)
808809
if self._ci is not arg._ci:
809-
items = []
810810
for e in arg._keys.iter_entries():
811811
identity = identity_func(e.key)
812-
items.append(_Entry(hash(identity), identity, e.key, e.value))
812+
yield _Entry(hash(identity), identity, e.key, e.value)
813813
else:
814-
items = [
815-
_Entry(e.hash, e.identity, e.key, e.value)
816-
for e in arg._keys.iter_entries()
817-
]
814+
for e in arg._keys.iter_entries():
815+
yield _Entry(e.hash, e.identity, e.key, e.value)
818816
if kwargs:
819817
for key, value in kwargs.items():
820818
identity = identity_func(key)
821-
items.append(_Entry(hash(identity), identity, key, value))
819+
yield _Entry(hash(identity), identity, key, value)
822820
else:
823821
if hasattr(arg, "keys"):
824822
arg = cast(SupportsKeys[_V], arg)
825823
arg = [(k, arg[k]) for k in arg.keys()]
826824
if kwargs:
827825
arg = list(arg)
828826
arg.extend(list(kwargs.items()))
829-
items = []
827+
try:
828+
yield len(arg) + len(kwargs) # type: ignore[arg-type]
829+
except TypeError:
830+
yield 0
830831
for pos, item in enumerate(arg):
831832
if not len(item) == 2:
832833
raise ValueError(
833834
f"multidict update sequence element #{pos}"
834835
f"has length {len(item)}; 2 is required"
835836
)
836837
identity = identity_func(item[0])
837-
items.append(_Entry(hash(identity), identity, item[0], item[1]))
838+
yield _Entry(hash(identity), identity, item[0], item[1])
838839
else:
839-
items = []
840+
yield len(kwargs)
840841
for key, value in kwargs.items():
841842
identity = identity_func(key)
842-
items.append(_Entry(hash(identity), identity, key, value))
843-
844-
return items
843+
yield _Entry(hash(identity), identity, key, value)
845844

846845
def _extend_items(self, items: Iterable[_Entry[_V]]) -> None:
847846
for e in items:
@@ -989,19 +988,21 @@ def popitem(self) -> tuple[str, _V]:
989988

990989
def update(self, arg: MDArg[_V] = None, /, **kwargs: _V) -> None:
991990
"""Update the dictionary, overwriting existing keys."""
992-
items = self._parse_args(arg, kwargs)
993-
newsize = self._used + len(items)
991+
it = self._parse_args(arg, kwargs)
992+
newsize = self._used + cast(int, next(it))
994993
log2_size = estimate_log2_keysize(newsize)
995994
if log2_size > 17: # pragma: no cover
996995
# Don't overallocate really huge keys space in update,
997996
# duplicate keys could reduce the resulting amount of entries
998997
log2_size = 17
999998
if log2_size > self._keys.log2_size:
1000999
self._resize(log2_size, False)
1001-
self._update_items(items)
1002-
self._post_update()
1000+
try:
1001+
self._update_items(cast(Iterator[_Entry[_V]], it))
1002+
finally:
1003+
self._post_update()
10031004

1004-
def _update_items(self, items: list[_Entry[_V]]) -> None:
1005+
def _update_items(self, items: Iterator[_Entry[_V]]) -> None:
10051006
for entry in items:
10061007
found = False
10071008
hash_ = entry.hash
@@ -1038,19 +1039,21 @@ def _post_update(self) -> None:
10381039

10391040
def merge(self, arg: MDArg[_V] = None, /, **kwargs: _V) -> None:
10401041
"""Merge into the dictionary, adding non-existing keys."""
1041-
items = self._parse_args(arg, kwargs)
1042-
newsize = self._used + len(items)
1042+
it = self._parse_args(arg, kwargs)
1043+
newsize = self._used + cast(int, next(it))
10431044
log2_size = estimate_log2_keysize(newsize)
10441045
if log2_size > 17: # pragma: no cover
10451046
# Don't overallocate really huge keys space in update,
10461047
# duplicate keys could reduce the resulting amount of entries
10471048
log2_size = 17
10481049
if log2_size > self._keys.log2_size:
10491050
self._resize(log2_size, False)
1050-
self._merge_items(items)
1051-
self._post_update()
1051+
try:
1052+
self._merge_items(cast(Iterator[_Entry[_V]], it))
1053+
finally:
1054+
self._post_update()
10521055

1053-
def _merge_items(self, items: list[_Entry[_V]]) -> None:
1056+
def _merge_items(self, items: Iterator[_Entry[_V]]) -> None:
10541057
for entry in items:
10551058
hash_ = entry.hash
10561059
identity = entry.identity

multidict/_multilib/hashtable.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@ _ci_key_to_identity(mod_state *state, PyObject *key)
151151
}
152152
return ret;
153153
}
154-
fail:
155154
PyErr_SetString(PyExc_TypeError,
156155
"CIMultiDict keys should be either str "
157156
"or subclasses of str");
157+
fail:
158158
return NULL;
159159
}
160160

@@ -1287,7 +1287,7 @@ _md_merge(MultiDictObject *md, Py_hash_t hash, PyObject *identity,
12871287
return -1;
12881288
}
12891289

1290-
static inline int
1290+
static inline void
12911291
md_post_update(MultiDictObject *md)
12921292
{
12931293
htkeys_t *keys = md->keys;
@@ -1306,14 +1306,11 @@ md_post_update(MultiDictObject *md)
13061306
}
13071307
if (entry->hash == -1) {
13081308
entry->hash = _unicode_hash(entry->identity);
1309-
if (entry->hash == -1) {
1310-
// hash of string always exists but still
1311-
return -1;
1312-
}
13131309
}
1310+
assert(entry->hash != -1);
13141311
}
13151312
}
1316-
return 0;
1313+
ASSERT_CONSISTENT(md, false);
13171314
}
13181315

13191316
static inline int

tests/test_mutable_multidict.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,3 +858,17 @@ def test_issue_1195(
858858
assert md.keys() == md2.keys() - {"User-Agent"}
859859
md.update([("User-Agent", b"Bacon/1.0")])
860860
assert md.keys() == md2.keys()
861+
862+
def test_update_with_crash_in_the_middle(
863+
self, case_insensitive_multidict_class: type[CIMultiDict[str]]
864+
) -> None:
865+
class Hack(str):
866+
def lower(self) -> str:
867+
raise RuntimeError
868+
869+
d = case_insensitive_multidict_class([("a", "a"), ("b", "b")])
870+
with pytest.raises(RuntimeError):
871+
lst = [("c", "c"), ("a", "a2"), (Hack("b"), "b2")]
872+
d.update(lst)
873+
874+
assert [("a", "a2"), ("b", "b"), ("c", "c")] == list(d.items())

0 commit comments

Comments
 (0)