Skip to content

Commit 7642044

Browse files
committed
Add MapMutation.update(); make creating Map from a Map faster; fix bugs
* Add new MapMutation.update() method that behaves like MutableMapping.update() * Make it faster to create a Map() from another Map() -- it's now an O(1) operation. * update() method had a bug that could cause the update Map object to have a wrong number of elements.
1 parent 5a9c2fa commit 7642044

File tree

3 files changed

+160
-13
lines changed

3 files changed

+160
-13
lines changed

immutables/_map.c

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ map_node_update(uint64_t mutid,
387387

388388

389389
static int
390-
map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src);
390+
map_update_inplace(uint64_t mutid, BaseMapObject *o, PyObject *src);
391391

392392
static MapObject *
393393
map_update(uint64_t mutid, MapObject *o, PyObject *src);
@@ -2153,6 +2153,8 @@ map_node_assoc(MapNode *node,
21532153
map_node_{nodetype}_assoc method.
21542154
*/
21552155

2156+
*added_leaf = 0;
2157+
21562158
if (IS_BITMAP_NODE(node)) {
21572159
return map_node_bitmap_assoc(
21582160
(MapNode_Bitmap *)node,
@@ -2892,10 +2894,27 @@ map_tp_init(MapObject *self, PyObject *args, PyObject *kwds)
28922894
}
28932895

28942896
if (arg != NULL) {
2895-
mutid = mutid_counter++;
2896-
if (map_update_inplace(mutid, self, arg)) {
2897+
if (Map_Check(arg)) {
2898+
MapObject *other = (MapObject *)arg;
2899+
2900+
Py_INCREF(other->h_root);
2901+
Py_SETREF(self->h_root, other->h_root);
2902+
2903+
self->h_count = other->h_count;
2904+
self->h_hash = other->h_hash;
2905+
}
2906+
else if (MapMutation_Check(arg)) {
2907+
PyErr_Format(
2908+
PyExc_TypeError,
2909+
"cannot create Maps from MapMutations");
28972910
return -1;
28982911
}
2912+
else {
2913+
mutid = mutid_counter++;
2914+
if (map_update_inplace(mutid, (BaseMapObject *)self, arg)) {
2915+
return -1;
2916+
}
2917+
}
28992918
}
29002919

29012920
if (kwds != NULL) {
@@ -2907,7 +2926,7 @@ map_tp_init(MapObject *self, PyObject *args, PyObject *kwds)
29072926
mutid = mutid_counter++;
29082927
}
29092928

2910-
if (map_update_inplace(mutid, self, kwds)) {
2929+
if (map_update_inplace(mutid, (BaseMapObject *)self, kwds)) {
29112930
return -1;
29122931
}
29132932
}
@@ -3665,14 +3684,14 @@ map_node_update(uint64_t mutid,
36653684

36663685

36673686
static int
3668-
map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src)
3687+
map_update_inplace(uint64_t mutid, BaseMapObject *o, PyObject *src)
36693688
{
36703689
MapNode *new_root = NULL;
36713690
Py_ssize_t new_count;
36723691

36733692
int ret = map_node_update(
36743693
mutid, src,
3675-
o->h_root, o->h_count,
3694+
o->b_root, o->b_count,
36763695
&new_root, &new_count);
36773696

36783697
if (ret) {
@@ -3681,8 +3700,8 @@ map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src)
36813700

36823701
assert(new_root);
36833702

3684-
Py_SETREF(o->h_root, new_root);
3685-
o->h_count = new_count;
3703+
Py_SETREF(o->b_root, new_root);
3704+
o->b_count = new_count;
36863705

36873706
return 0;
36883707
}
@@ -3852,6 +3871,35 @@ mapmut_tp_richcompare(PyObject *v, PyObject *w, int op)
38523871
}
38533872
}
38543873

3874+
static PyObject *
3875+
mapmut_py_update(MapMutationObject *self, PyObject *args, PyObject *kwds)
3876+
{
3877+
PyObject *arg = NULL;
3878+
3879+
if (!PyArg_UnpackTuple(args, "update", 0, 1, &arg)) {
3880+
return NULL;
3881+
}
3882+
3883+
if (arg != NULL) {
3884+
if (map_update_inplace(self->m_mutid, (BaseMapObject *)self, arg)) {
3885+
return NULL;
3886+
}
3887+
}
3888+
3889+
if (kwds != NULL) {
3890+
if (!PyArg_ValidateKeywordArguments(kwds)) {
3891+
return NULL;
3892+
}
3893+
3894+
if (map_update_inplace(self->m_mutid, (BaseMapObject *)self, kwds)) {
3895+
return NULL;
3896+
}
3897+
}
3898+
3899+
Py_RETURN_NONE;
3900+
}
3901+
3902+
38553903
static PyObject *
38563904
mapmut_py_finalize(MapMutationObject *self, PyObject *args)
38573905
{
@@ -3970,6 +4018,8 @@ static PyMethodDef MapMutation_methods[] = {
39704018
{"get", (PyCFunction)map_py_get, METH_VARARGS, NULL},
39714019
{"pop", (PyCFunction)mapmut_py_pop, METH_VARARGS, NULL},
39724020
{"finish", (PyCFunction)mapmut_py_finalize, METH_NOARGS, NULL},
4021+
{"update", (PyCFunction)mapmut_py_update,
4022+
METH_VARARGS | METH_KEYWORDS, NULL},
39734023
{"__enter__", (PyCFunction)mapmut_py_enter, METH_NOARGS, NULL},
39744024
{"__exit__", (PyCFunction)mapmut_py_exit, METH_VARARGS, NULL},
39754025
{NULL, NULL}

immutables/map.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,14 @@ def __init__(self, col=None, **kw):
438438
self.__root = BitmapNode(0, 0, [], 0)
439439
self.__hash = -1
440440

441+
if isinstance(col, Map):
442+
self.__count = col.__count
443+
self.__root = col.__root
444+
self.__hash = col.__hash
445+
col = None
446+
elif isinstance(col, MapMutation):
447+
raise TypeError('cannot create Maps from MapMutations')
448+
441449
if col or kw:
442450
init = self.update(col, **kw)
443451
self.__count = init.__count
@@ -640,6 +648,9 @@ def __exit__(self, *exc):
640648
self.finish()
641649
return False
642650

651+
def __iter__(self):
652+
raise TypeError('{} is not iterable'.format(type(self)))
653+
643654
def __delitem__(self, key):
644655
if self.__mutid == 0:
645656
raise ValueError('mutation {!r} has been finished'.format(self))
@@ -707,6 +718,56 @@ def __contains__(self, key):
707718
else:
708719
return True
709720

721+
def update(self, col=None, **kw):
722+
it = None
723+
if col is not None:
724+
if hasattr(col, 'items'):
725+
it = iter(col.items())
726+
else:
727+
it = iter(col)
728+
729+
if it is not None:
730+
if kw:
731+
it = iter(itertools.chain(it, kw.items()))
732+
else:
733+
if kw:
734+
it = iter(kw.items())
735+
736+
if it is None:
737+
738+
return self
739+
740+
root = self.__root
741+
count = self.__count
742+
743+
i = 0
744+
while True:
745+
try:
746+
tup = next(it)
747+
except StopIteration:
748+
break
749+
750+
try:
751+
tup = tuple(tup)
752+
except TypeError:
753+
raise TypeError(
754+
'cannot convert map update '
755+
'sequence element #{} to a sequence'.format(i)) from None
756+
key, val, *r = tup
757+
if r:
758+
raise ValueError(
759+
'map update sequence element #{} has length '
760+
'{}; 2 is required'.format(i, len(r) + 2))
761+
762+
root, added = root.assoc(0, map_hash(key), key, val, self.__mutid)
763+
if added:
764+
count += 1
765+
766+
i += 1
767+
768+
self.__root = root
769+
self.__count = count
770+
710771
def finish(self):
711772
self.__mutid = 0
712773
return Map._new(self.__count, self.__root)

tests/test_map.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,11 +1089,6 @@ def test_map_mut_9(self):
10891089
with self.assertRaises(HashingError):
10901090
self.Map(src)
10911091

1092-
src = self.Map({key1: 123})
1093-
with HashKeyCrasher(error_on_hash=True):
1094-
with self.assertRaises(HashingError):
1095-
self.Map(src)
1096-
10971092
src = [(1, 2), (key1, 123)]
10981093
with HashKeyCrasher(error_on_hash=True):
10991094
with self.assertRaises(HashingError):
@@ -1195,6 +1190,47 @@ def test_map_mut_15(self):
11951190
self.assertEqual(mm.finish(), self.Map(z=100, b=2))
11961191
self.assertEqual(m, self.Map(a=1, b=2))
11971192

1193+
def test_map_mut_16(self):
1194+
m = self.Map(a=1, b=2)
1195+
hash(m)
1196+
1197+
m2 = self.Map(m)
1198+
m3 = self.Map(m, c=3)
1199+
1200+
self.assertEqual(m, m2)
1201+
self.assertEqual(len(m), len(m2))
1202+
self.assertEqual(hash(m), hash(m2))
1203+
1204+
self.assertIsNot(m, m2)
1205+
self.assertEqual(m3, self.Map(a=1, b=2, c=3))
1206+
1207+
def test_map_mut_17(self):
1208+
m = self.Map(a=1)
1209+
with m.mutate() as mm:
1210+
with self.assertRaisesRegex(
1211+
TypeError, 'cannot create Maps from MapMutations'):
1212+
self.Map(mm)
1213+
1214+
def test_map_mut_18(self):
1215+
m = self.Map(a=1, b=2)
1216+
with m.mutate() as mm:
1217+
mm.update(self.Map(x=1), z=2)
1218+
mm.update(c=3)
1219+
mm.update({'n': 100, 'a': 20})
1220+
m2 = mm.finish()
1221+
1222+
expected = self.Map(
1223+
{'b': 2, 'c': 3, 'n': 100, 'z': 2, 'x': 1, 'a': 20})
1224+
1225+
self.assertEqual(len(m2), 6)
1226+
self.assertEqual(m2, expected)
1227+
self.assertEqual(m, self.Map(a=1, b=2))
1228+
1229+
def test_map_mut_19(self):
1230+
m = self.Map(a=1, b=2)
1231+
m2 = m.update({'a': 20})
1232+
self.assertEqual(len(m2), 2)
1233+
11981234
def test_map_mut_stress(self):
11991235
COLLECTION_SIZE = 7000
12001236
TEST_ITERS_EVERY = 647

0 commit comments

Comments
 (0)