Skip to content

Commit 36fa231

Browse files
committed
ENH: make np.dtype(scalar_type) return the default dtype instance for new dtypes
1 parent f1b2fb5 commit 36fa231

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

numpy/_core/src/multiarray/descriptor.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "npy_buffer.h"
3030
#include "dtypemeta.h"
3131
#include "stringdtype/dtype.h"
32+
#include "array_coercion.h"
3233

3334
#ifndef PyDictProxy_Check
3435
#define PyDictProxy_Check(obj) (Py_TYPE(obj) == &PyDictProxy_Type)
@@ -1556,6 +1557,14 @@ PyArray_GetDefaultDescr(PyArray_DTypeMeta *DType)
15561557
return NPY_DT_CALL_default_descr(DType);
15571558
}
15581559

1560+
NPY_NO_EXPORT PyArray_Descr *
1561+
default_descr_from_scalar_type(PyTypeObject *typ) {
1562+
PyObject *DType = PyArray_DiscoverDTypeFromScalarType(typ);
1563+
if (DType == NULL || DType == Py_None) {
1564+
return NULL;
1565+
}
1566+
return PyArray_GetDefaultDescr((PyArray_DTypeMeta *)DType);
1567+
}
15591568

15601569
/**
15611570
* Get a dtype instance from a python type
@@ -1600,6 +1609,10 @@ _convert_from_type(PyObject *obj) {
16001609
return PyArray_DescrFromType(NPY_OBJECT);
16011610
}
16021611
else {
1612+
PyArray_Descr *descr = default_descr_from_scalar_type(typ);
1613+
if (descr != NULL) {
1614+
return descr;
1615+
}
16031616
PyArray_Descr *ret = _try_convert_from_dtype_attr(obj);
16041617
if ((PyObject *)ret != Py_NotImplemented) {
16051618
return ret;

numpy/_core/src/multiarray/descriptor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,7 @@ arraydescr_field_subset_view(_PyArray_LegacyDescr *self, PyObject *ind);
6565

6666
extern NPY_NO_EXPORT char const *_datetime_strings[];
6767

68+
NPY_NO_EXPORT PyArray_Descr *
69+
default_descr_from_scalar_type(PyTypeObject *typ);
70+
6871
#endif /* NUMPY_CORE_SRC_MULTIARRAY_DESCRIPTOR_H_ */

numpy/_core/src/multiarray/scalarapi.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,10 @@ PyArray_DescrFromTypeObject(PyObject *type)
390390
Py_INCREF(type);
391391
return (PyArray_Descr *)new;
392392
}
393-
return _descr_from_subtype(type);
393+
394+
PyArray_Descr *default_descr = default_descr_from_scalar_type((PyTypeObject *)type);
395+
396+
return default_descr != NULL ? default_descr : _descr_from_subtype(type);
394397
}
395398

396399

0 commit comments

Comments
 (0)