Skip to content

Commit 3eafa3f

Browse files
mtsokolngoldbaum
andcommitted
ENH: Add a __dict__ to ufunc objects and use it to allow overriding __doc__
Co-authored-by: Nathan Goldbaum <[email protected]>
1 parent 20d051a commit 3eafa3f

File tree

6 files changed

+71
-4
lines changed

6 files changed

+71
-4
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
* UFuncs new support `__dict__` attribute and allow overriding
2+
`__doc__` (either directly or via `ufunc.__dict__["__doc__"]`).

numpy/_core/include/numpy/ufuncobject.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,10 @@ typedef struct _tagPyUFuncObject {
170170
* with the dtypes for the inputs and outputs.
171171
*/
172172
PyUFunc_TypeResolutionFunc *type_resolver;
173-
/* Was the legacy loop resolver */
174-
void *reserved2;
173+
174+
/* A dictionary to monkeypatch ufuncs */
175+
PyObject *dict;
176+
175177
/*
176178
* This was blocked off to be the "new" inner loop selector in 1.7,
177179
* but this was never implemented. (This is also why the above

numpy/_core/src/multiarray/npy_static_data.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ intern_strings(void)
6363
INTERN_STRING(__dlpack__, "__dlpack__");
6464
INTERN_STRING(pyvals_name, "UFUNC_PYVALS_NAME");
6565
INTERN_STRING(legacy, "legacy");
66+
INTERN_STRING(__doc__, "__doc__");
6667
return 0;
6768
}
6869

numpy/_core/src/multiarray/npy_static_data.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ typedef struct npy_interned_str_struct {
3838
PyObject *__dlpack__;
3939
PyObject *pyvals_name;
4040
PyObject *legacy;
41+
PyObject *__doc__;
4142
} npy_interned_str_struct;
4243

4344
/*

numpy/_core/src/umath/ufunc_object.c

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4771,6 +4771,7 @@ PyUFunc_FromFuncAndDataAndSignatureAndIdentity(PyUFuncGenericFunction *func, voi
47714771
return NULL;
47724772
}
47734773
}
4774+
ufunc->dict = PyDict_New();
47744775
/*
47754776
* TODO: I tried adding a default promoter here (either all object for
47764777
* some special cases, or all homogeneous). Those are reasonable
@@ -6411,6 +6412,15 @@ ufunc_get_doc(PyUFuncObject *ufunc, void *NPY_UNUSED(ignored))
64116412
{
64126413
PyObject *doc;
64136414

6415+
// If there is a __doc__ in the instance __dict__, use it.
6416+
int result = PyDict_GetItemRef(ufunc->dict, npy_interned_str.__doc__, &doc);
6417+
if (result == -1) {
6418+
return NULL;
6419+
}
6420+
else if (result == 1) {
6421+
return doc;
6422+
}
6423+
64146424
if (npy_cache_import_runtime(
64156425
"numpy._core._internal", "_ufunc_doc_signature_formatter",
64166426
&npy_runtime_imports._ufunc_doc_signature_formatter) == -1) {
@@ -6434,6 +6444,20 @@ ufunc_get_doc(PyUFuncObject *ufunc, void *NPY_UNUSED(ignored))
64346444
return doc;
64356445
}
64366446

6447+
static int
6448+
ufunc_set_doc(PyUFuncObject *ufunc, PyObject *doc, void *NPY_UNUSED(ignored))
6449+
{
6450+
if (doc == NULL) {
6451+
int result = PyDict_Contains(ufunc->dict, npy_interned_str.__doc__);
6452+
if (result == 1) {
6453+
return PyDict_DelItem(ufunc->dict, npy_interned_str.__doc__);
6454+
} else {
6455+
return result;
6456+
}
6457+
} else {
6458+
return PyDict_SetItem(ufunc->dict, npy_interned_str.__doc__, doc);
6459+
}
6460+
}
64376461

64386462
static PyObject *
64396463
ufunc_get_nin(PyUFuncObject *ufunc, void *NPY_UNUSED(ignored))
@@ -6519,8 +6543,8 @@ ufunc_get_signature(PyUFuncObject *ufunc, void *NPY_UNUSED(ignored))
65196543

65206544
static PyGetSetDef ufunc_getset[] = {
65216545
{"__doc__",
6522-
(getter)ufunc_get_doc,
6523-
NULL, NULL, NULL},
6546+
(getter)ufunc_get_doc, (setter)ufunc_set_doc,
6547+
NULL, NULL},
65246548
{"nin",
65256549
(getter)ufunc_get_nin,
65266550
NULL, NULL, NULL},
@@ -6549,6 +6573,17 @@ static PyGetSetDef ufunc_getset[] = {
65496573
};
65506574

65516575

6576+
/******************************************************************************
6577+
*** UFUNC MEMBERS ***
6578+
*****************************************************************************/
6579+
6580+
static PyMemberDef ufunc_members[] = {
6581+
{"__dict__", T_OBJECT, offsetof(PyUFuncObject, dict),
6582+
READONLY},
6583+
{NULL},
6584+
};
6585+
6586+
65526587
/******************************************************************************
65536588
*** UFUNC TYPE OBJECT ***
65546589
*****************************************************************************/
@@ -6568,6 +6603,12 @@ NPY_NO_EXPORT PyTypeObject PyUFunc_Type = {
65686603
.tp_traverse = (traverseproc)ufunc_traverse,
65696604
.tp_methods = ufunc_methods,
65706605
.tp_getset = ufunc_getset,
6606+
.tp_getattro = PyObject_GenericGetAttr,
6607+
.tp_setattro = PyObject_GenericSetAttr,
6608+
// TODO when Python 3.12 is the minimum supported version,
6609+
// use Py_TPFLAGS_MANAGED_DICT
6610+
.tp_members = ufunc_members,
6611+
.tp_dictoffset = offsetof(PyUFuncObject, dict),
65716612
};
65726613

65736614
/* End of code for ufunc objects */

numpy/_core/tests/test_umath.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4016,6 +4016,26 @@ def test_array_ufunc_direct_call(self):
40164016
res = a.__array_ufunc__(np.add, "__call__", a, a)
40174017
assert_array_equal(res, a + a)
40184018

4019+
def test_ufunc_docstring(self):
4020+
original_doc = np.add.__doc__
4021+
new_doc = "new docs"
4022+
4023+
np.add.__doc__ = new_doc
4024+
assert np.add.__doc__ == new_doc
4025+
assert np.add.__dict__["__doc__"] == new_doc
4026+
4027+
del np.add.__doc__
4028+
assert np.add.__doc__ == original_doc
4029+
assert np.add.__dict__ == {}
4030+
4031+
np.add.__dict__["other"] = 1
4032+
np.add.__dict__["__doc__"] = new_doc
4033+
assert np.add.__doc__ == new_doc
4034+
4035+
del np.add.__dict__["__doc__"]
4036+
assert np.add.__doc__ == original_doc
4037+
4038+
40194039
class TestChoose:
40204040
def test_mixed(self):
40214041
c = np.array([True, True])

0 commit comments

Comments
 (0)