Skip to content

Commit ccb00b2

Browse files
authored
Merge pull request numpy#27434 from ngoldbaum/fix-dtype-new-dtype
ENH: make np.dtype(scalar_type) return the default dtype instance
2 parents f1b2fb5 + fad8a23 commit ccb00b2

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

numpy/_core/src/multiarray/array_coercion.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,13 +247,17 @@ npy_discover_dtype_from_pytype(PyTypeObject *pytype)
247247
}
248248

249249
/*
250-
* Note: This function never fails, but will return `NULL` for unknown scalars
251-
* and `None` for known array-likes (e.g. tuple, list, ndarray).
250+
* Note: This function never fails, but will return `NULL` for unknown scalars or
251+
* known array-likes (e.g. tuple, list, ndarray).
252252
*/
253253
NPY_NO_EXPORT PyObject *
254254
PyArray_DiscoverDTypeFromScalarType(PyTypeObject *pytype)
255255
{
256-
return (PyObject *)npy_discover_dtype_from_pytype(pytype);
256+
PyObject *DType = (PyObject *)npy_discover_dtype_from_pytype(pytype);
257+
if (DType == NULL || DType == Py_None) {
258+
return NULL;
259+
}
260+
return DType;
257261
}
258262

259263

numpy/_core/src/multiarray/descriptor.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "npy_buffer.h"
3030
#include "dtypemeta.h"
3131
#include "stringdtype/dtype.h"
32+
#include "array_coercion.h"
3233

3334
#ifndef PyDictProxy_Check
3435
#define PyDictProxy_Check(obj) (Py_TYPE(obj) == &PyDictProxy_Type)
@@ -1600,6 +1601,10 @@ _convert_from_type(PyObject *obj) {
16001601
return PyArray_DescrFromType(NPY_OBJECT);
16011602
}
16021603
else {
1604+
PyObject *DType = PyArray_DiscoverDTypeFromScalarType(typ);
1605+
if (DType != NULL) {
1606+
return PyArray_GetDefaultDescr((PyArray_DTypeMeta *)DType);
1607+
}
16031608
PyArray_Descr *ret = _try_convert_from_dtype_attr(obj);
16041609
if ((PyObject *)ret != Py_NotImplemented) {
16051610
return ret;

numpy/_core/src/multiarray/scalarapi.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,12 @@ PyArray_DescrFromTypeObject(PyObject *type)
390390
Py_INCREF(type);
391391
return (PyArray_Descr *)new;
392392
}
393+
394+
PyObject *DType = PyArray_DiscoverDTypeFromScalarType((PyTypeObject *)type);
395+
if (DType != NULL) {
396+
return PyArray_GetDefaultDescr((PyArray_DTypeMeta *)DType);
397+
}
398+
393399
return _descr_from_subtype(type);
394400
}
395401

0 commit comments

Comments
 (0)