Skip to content

Commit 58d84ab

Browse files
authored
Merge pull request #8828 from tannewt/fix_dict_subclass
Fix subclassing dict
2 parents 9efe5a2 + 283aac2 commit 58d84ab

File tree

3 files changed

+63
-15
lines changed

3 files changed

+63
-15
lines changed

py/objdict.c

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ const mp_obj_dict_t mp_const_empty_dict_obj = {
4949
}
5050
};
5151

52+
// CIRCUITPY-CHANGE: Native methods are passed the subclass instance so they can
53+
// refer to subclass members. Dict only cares about the native struct so this
54+
// function gets it.
55+
STATIC mp_obj_dict_t *native_dict(mp_obj_t self_in) {
56+
// Check for OrderedDict first because it is marked as a subclass of dict. However, it doesn't
57+
// store its state in subobj like python types to native types do.
58+
mp_obj_t native_instance = MP_OBJ_NULL;
59+
#if MICROPY_PY_COLLECTIONS_ORDEREDDICT
60+
native_instance = mp_obj_cast_to_native_base(self_in, MP_OBJ_FROM_PTR(&mp_type_ordereddict));
61+
#endif
62+
if (native_instance == MP_OBJ_NULL) {
63+
native_instance = mp_obj_cast_to_native_base(self_in, MP_OBJ_FROM_PTR(&mp_type_dict));
64+
}
65+
return MP_OBJ_TO_PTR(native_instance);
66+
}
67+
5268
STATIC mp_obj_t dict_update(size_t n_args, const mp_obj_t *args, mp_map_t *kwargs);
5369

5470
// This is a helper function to iterate through a dictionary. The state of
@@ -71,7 +87,7 @@ STATIC mp_map_elem_t *dict_iter_next(mp_obj_dict_t *dict, size_t *cur) {
7187
}
7288

7389
STATIC void dict_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) {
74-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_in);
90+
mp_obj_dict_t *self = native_dict(self_in);
7591
bool first = true;
7692
const char *item_separator = ", ";
7793
const char *key_separator = ": ";
@@ -144,7 +160,7 @@ mp_obj_t mp_obj_dict_make_new(const mp_obj_type_t *type, size_t n_args, size_t n
144160
}
145161

146162
STATIC mp_obj_t dict_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
147-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_in);
163+
mp_obj_dict_t *self = native_dict(self_in);
148164
switch (op) {
149165
case MP_UNARY_OP_BOOL:
150166
return mp_obj_new_bool(self->map.used != 0);
@@ -162,7 +178,7 @@ STATIC mp_obj_t dict_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
162178
}
163179

164180
STATIC mp_obj_t dict_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
165-
mp_obj_dict_t *o = MP_OBJ_TO_PTR(lhs_in);
181+
mp_obj_dict_t *o = native_dict(lhs_in);
166182
switch (op) {
167183
case MP_BINARY_OP_CONTAINS: {
168184
mp_map_elem_t *elem = mp_map_lookup(&o->map, rhs_in, MP_MAP_LOOKUP);
@@ -223,7 +239,7 @@ STATIC mp_obj_t dict_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_
223239

224240
// Note: Make sure this is inlined in load part of dict_subscr() below.
225241
mp_obj_t mp_obj_dict_get(mp_obj_t self_in, mp_obj_t index) {
226-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_in);
242+
mp_obj_dict_t *self = native_dict(self_in);
227243
mp_map_elem_t *elem = mp_map_lookup(&self->map, index, MP_MAP_LOOKUP);
228244
if (elem == NULL) {
229245
mp_raise_type_arg(&mp_type_KeyError, index);
@@ -239,7 +255,7 @@ STATIC mp_obj_t dict_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
239255
return mp_const_none;
240256
} else if (value == MP_OBJ_SENTINEL) {
241257
// load
242-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_in);
258+
mp_obj_dict_t *self = native_dict(self_in);
243259
mp_map_elem_t *elem = mp_map_lookup(&self->map, index, MP_MAP_LOOKUP);
244260
if (elem == NULL) {
245261
mp_raise_type_arg(&mp_type_KeyError, index);
@@ -264,7 +280,7 @@ STATIC void PLACE_IN_ITCM(mp_ensure_not_fixed)(const mp_obj_dict_t * dict) {
264280

265281
STATIC mp_obj_t dict_clear(mp_obj_t self_in) {
266282
mp_check_self(mp_obj_is_dict_or_ordereddict(self_in));
267-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_in);
283+
mp_obj_dict_t *self = native_dict(self_in);
268284
mp_ensure_not_fixed(self);
269285

270286
mp_map_clear(&self->map);
@@ -275,9 +291,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_1(dict_clear_obj, dict_clear);
275291

276292
mp_obj_t mp_obj_dict_copy(mp_obj_t self_in) {
277293
mp_check_self(mp_obj_is_dict_or_ordereddict(self_in));
278-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_in);
294+
mp_obj_dict_t *self = native_dict(self_in);
279295
mp_obj_t other_out = mp_obj_new_dict(self->map.alloc);
280-
mp_obj_dict_t *other = MP_OBJ_TO_PTR(other_out);
296+
mp_obj_dict_t *other = native_dict(other_out);
281297
other->base.type = self->base.type;
282298
other->map.used = self->map.used;
283299
other->map.all_keys_are_qstrs = self->map.all_keys_are_qstrs;
@@ -324,7 +340,7 @@ STATIC MP_DEFINE_CONST_CLASSMETHOD_OBJ(dict_fromkeys_obj, MP_ROM_PTR(&dict_fromk
324340

325341
STATIC mp_obj_t dict_get_helper(size_t n_args, const mp_obj_t *args, mp_map_lookup_kind_t lookup_kind) {
326342
mp_check_self(mp_obj_is_dict_or_ordereddict(args[0]));
327-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(args[0]);
343+
mp_obj_dict_t *self = native_dict(args[0]);
328344
if (lookup_kind != MP_MAP_LOOKUP) {
329345
mp_ensure_not_fixed(self);
330346
}
@@ -369,7 +385,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(dict_setdefault_obj, 2, 3, dict_setde
369385

370386
STATIC mp_obj_t dict_popitem(mp_obj_t self_in) {
371387
mp_check_self(mp_obj_is_dict_or_ordereddict(self_in));
372-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_in);
388+
mp_obj_dict_t *self = native_dict(self_in);
373389
mp_ensure_not_fixed(self);
374390
if (self->map.used == 0) {
375391
mp_raise_msg_varg(&mp_type_KeyError, MP_ERROR_TEXT("pop from empty %q"), MP_QSTR_dict);
@@ -394,7 +410,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_1(dict_popitem_obj, dict_popitem);
394410

395411
STATIC mp_obj_t dict_update(size_t n_args, const mp_obj_t *args, mp_map_t *kwargs) {
396412
mp_check_self(mp_obj_is_dict_or_ordereddict(args[0]));
397-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(args[0]);
413+
mp_obj_dict_t *self = native_dict(args[0]);
398414
mp_ensure_not_fixed(self);
399415

400416
mp_arg_check_num(n_args, kwargs->used, 1, 2, true);
@@ -726,7 +742,7 @@ size_t mp_obj_dict_len(mp_obj_t self_in) {
726742

727743
mp_obj_t mp_obj_dict_store(mp_obj_t self_in, mp_obj_t key, mp_obj_t value) {
728744
mp_check_self(mp_obj_is_dict_or_ordereddict(self_in));
729-
mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_in);
745+
mp_obj_dict_t *self = native_dict(self_in);
730746
mp_ensure_not_fixed(self);
731747
mp_map_lookup(&self->map, key, MP_MAP_LOOKUP_ADD_IF_NOT_FOUND)->value = value;
732748
return self_in;

py/opmethods.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,26 @@
2727
#include "py/obj.h"
2828
#include "py/builtin.h"
2929

30+
// CIRCUITPY-CHANGE: These three functions are used by dict only. In CP, we hard
31+
// code the type to dict so that subclassed types still use the native dict
32+
// subscr. MP doesn't have this problem because it passes the native instance
33+
// in. CP passes the subclass instance.
3034
STATIC mp_obj_t op_getitem(mp_obj_t self_in, mp_obj_t key_in) {
31-
const mp_obj_type_t *type = mp_obj_get_type(self_in);
35+
const mp_obj_type_t *type = &mp_type_dict;
3236
// Note: assumes type must have subscr (only used by dict).
3337
return MP_OBJ_TYPE_GET_SLOT(type, subscr)(self_in, key_in, MP_OBJ_SENTINEL);
3438
}
3539
MP_DEFINE_CONST_FUN_OBJ_2(mp_op_getitem_obj, op_getitem);
3640

3741
STATIC mp_obj_t op_setitem(mp_obj_t self_in, mp_obj_t key_in, mp_obj_t value_in) {
38-
const mp_obj_type_t *type = mp_obj_get_type(self_in);
42+
const mp_obj_type_t *type = &mp_type_dict;
3943
// Note: assumes type must have subscr (only used by dict).
4044
return MP_OBJ_TYPE_GET_SLOT(type, subscr)(self_in, key_in, value_in);
4145
}
4246
MP_DEFINE_CONST_FUN_OBJ_3(mp_op_setitem_obj, op_setitem);
4347

4448
STATIC mp_obj_t op_delitem(mp_obj_t self_in, mp_obj_t key_in) {
45-
const mp_obj_type_t *type = mp_obj_get_type(self_in);
49+
const mp_obj_type_t *type = &mp_type_dict;
4650
// Note: assumes type must have subscr (only used by dict).
4751
return MP_OBJ_TYPE_GET_SLOT(type, subscr)(self_in, key_in, MP_OBJ_NULL);
4852
}

tests/basics/subclass_native_dict.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
class a:
2+
def __init__(self):
3+
self.d = {}
4+
5+
def __setitem__(self, k, v):
6+
print("a", k, v)
7+
self.d[k] = v
8+
9+
def __getitem__(self, k):
10+
return self.d[k]
11+
12+
class b(a):
13+
def __setitem__(self, k, v):
14+
print("b", k, v)
15+
super().__setitem__(k, v)
16+
17+
b1 = b()
18+
b1[1] = 2
19+
print(b1[1])
20+
21+
class mydict(dict):
22+
def __setitem__(self, k, v):
23+
print(k, v)
24+
super().__setitem__(k, v)
25+
26+
d = mydict()
27+
d[3] = 4
28+
print(d[3])

0 commit comments

Comments
 (0)