diff --git a/src/msgspec/_core.c b/src/msgspec/_core.c index d4151cd8..56d247ac 100644 --- a/src/msgspec/_core.c +++ b/src/msgspec/_core.c @@ -519,6 +519,7 @@ typedef struct { #endif PyObject *astimezone; PyObject *re_compile; + PyObject *copy_deepcopy; uint8_t gc_cycle; } MsgspecState; @@ -7938,6 +7939,66 @@ Struct_copy(PyObject *self, PyObject *args) return NULL; } + +static PyObject* get_deepcopy_func() { + // lazily copy.deepcopy and cache in global state + PyObject *copy_mod, *deepcopy_func; + MsgspecState* st = msgspec_get_global_state(); + deepcopy_func = st->copy_deepcopy; + if (deepcopy_func == NULL) { + copy_mod = PyImport_ImportModule("copy"); + if (copy_mod == NULL) return NULL; + deepcopy_func = PyObject_GetAttrString(copy_mod, "deepcopy"); + st->copy_deepcopy = deepcopy_func; + Py_DECREF(copy_mod); + if (st->copy_deepcopy == NULL) return NULL; + } + + return deepcopy_func; +} + +static PyObject * +Struct_deepcopy(PyObject *self, PyObject *args) +{ + PyObject *memo; + PyObject *val = NULL, *res = NULL, *dc_val = NULL; + PyObject *deepcopy_func; + Py_ssize_t i, nfields; + + if (!PyArg_ParseTuple(args, "O!:__deepcopy__", &PyDict_Type, &memo)) + return NULL; + + deepcopy_func = get_deepcopy_func(); + + res = Struct_alloc(Py_TYPE(self)); + if (res == NULL) + return NULL; + + nfields = StructMeta_GET_NFIELDS(Py_TYPE(self)); + for (i = 0; i < nfields; i++) { + val = Struct_get_index(self, i); + if (val == NULL) + goto error; + + dc_val = PyObject_CallFunctionObjArgs(deepcopy_func, val, memo, NULL); + if (dc_val == NULL) + goto error; + + Struct_set_index(res, i, dc_val); + } + + /* If self is tracked, then copy is tracked */ + if (MS_OBJECT_IS_GC(self) && MS_IS_TRACKED(self)) + PyObject_GC_Track(res); + + return res; + +error: + Py_DECREF(res); + return NULL; +} + + static PyObject * Struct_replace( PyObject *self, @@ -8002,6 +8063,8 @@ Struct_replace( } } + if (Struct_post_init(struct_type, out) < 0) goto error; + if (is_gc && !should_untrack) { PyObject_GC_Track(out); } @@ -8358,6 +8421,7 @@ StructMixin_config(StructMetaObject *self, void *closure) { static PyMethodDef Struct_methods[] = { {"__copy__", Struct_copy, METH_NOARGS, "copy a struct"}, + {"__deepcopy__", Struct_deepcopy, METH_VARARGS, "deepcopy a struct"}, {"__replace__", (PyCFunction) Struct_replace, METH_FASTCALL | METH_KEYWORDS, "create a new struct with replacements" }, {"__reduce__", Struct_reduce, METH_NOARGS, "reduce a struct"}, {"__rich_repr__", Struct_rich_repr, METH_NOARGS, "rich repr"}, @@ -22306,6 +22370,7 @@ msgspec_clear(PyObject *m) #endif Py_CLEAR(st->astimezone); Py_CLEAR(st->re_compile); + Py_CLEAR(st->copy_deepcopy); return 0; } @@ -22380,6 +22445,7 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg) #endif Py_VISIT(st->astimezone); Py_VISIT(st->re_compile); + Py_VISIT(st->copy_deepcopy); return 0; } @@ -22674,6 +22740,9 @@ PyInit__core(void) Py_DECREF(temp_module); if (st->re_compile == NULL) return NULL; + // cache for 'copy.deepcopy'. to access this function, use 'get_deepcopy_func' + st->copy_deepcopy = NULL; + /* Initialize cached constant strings */ #define CACHED_STRING(attr, str) \ if ((st->attr = PyUnicode_InternFromString(str)) == NULL) return NULL diff --git a/tests/unit/test_struct.py b/tests/unit/test_struct.py index 42f5306c..fb839473 100644 --- a/tests/unit/test_struct.py +++ b/tests/unit/test_struct.py @@ -810,12 +810,65 @@ class Test(Struct): b: int a: int - x = copy.copy(Test(1, 2)) + o = Test(1, 2) + x = copy.copy(o) assert type(x) is Test + assert x is not o assert x.b == 1 assert x.a == 2 +def test_struct_deepcopy(): + o = Struct() + x = copy.deepcopy(Struct()) + assert type(x) is Struct + assert x is not o + + class Sub(Struct): + one: str + two: list[int] + + class Test(Struct): + a: int + b: int + c: list[str] + sub: Sub + + o = Test( + a=1, + b=2, + c=["1", "2"], + sub=Sub(one="hello", two=[3]), + ) + x = copy.deepcopy(o) + assert type(x) is Test + assert x.a == 1 + assert x.b == 2 + assert x.c == ["1", "2"] + assert x.c is not o.c + assert x.sub is not o.sub + assert x.sub.one == "hello" + assert x.sub.two == [3] + assert x.sub.two is not o.sub.two + + +def test_struct_deepcopy_custom_impl(): + # ensure we respect custom __deepcopy__ methods + class CustomThing: + def __init__(self, value): + self.value = value + + def __deepcopy__(self, memo): + return CustomThing(value=self.value + 1) + + class TestWithCustom(Struct): + custom: CustomThing + + t = TestWithCustom(CustomThing(1)) + tc = copy.deepcopy(t) + assert tc.custom.value == 2 + + class FrozenPoint(Struct, frozen=True): x: int y: int @@ -2664,7 +2717,7 @@ def __post_init__(self): assert x1 == x2 assert count == 1 - def test_post_init_not_called_on_replace(self): + def test_post_init_not_called_on_deepcopy(self): count = 0 class Ex(Struct): @@ -2674,6 +2727,20 @@ def __post_init__(self): x1 = Ex() assert count == 1 - x2 = msgspec.structs.replace(x1) + x2 = copy.deepcopy(x1) assert x1 == x2 assert count == 1 + + def test_post_init_called_on_replace(self, replace): + count = 0 + + class Ex(Struct): + def __post_init__(self): + nonlocal count + count += 1 + + x1 = Ex() + assert count == 1 + x2 = replace(x1) + assert x1 == x2 + assert count == 2