diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx
index e1a2a0142c52e..a5cc78d69e5f7 100644
--- a/pandas/_libs/lib.pyx
+++ b/pandas/_libs/lib.pyx
@@ -107,6 +107,8 @@ from pandas._libs.tslibs.period cimport is_period_object
 from pandas._libs.tslibs.timedeltas cimport convert_to_timedelta64
 from pandas._libs.tslibs.timezones cimport tz_compare
 
+from pandas.core.dtypes.base import _registry
+
 # constants that will be compared to potentially arbitrarily large
 # python int
 cdef:
@@ -1693,6 +1695,12 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
     if is_interval_array(values):
         return "interval"
 
+    # Ask the dtype registry whether ``val`` (bound earlier in this function)
+    # unambiguously identifies a registered extension dtype.
+    reg_dtype = _registry.match_scalar(val)
+    if reg_dtype is not None:
+        return str(reg_dtype)
+
     cnp.PyArray_ITER_RESET(it)
     for i in range(n):
         val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
diff --git a/pandas/core/construction.py b/pandas/core/construction.py
index 665eb75953078..06e81264cb2c3 100644
--- a/pandas/core/construction.py
+++ b/pandas/core/construction.py
@@ -370,6 +370,10 @@ def array(
 
         elif data.dtype.kind == "b":
             return BooleanArray._from_sequence(data, dtype="boolean", copy=copy)
+        # elif inferred_dtype != "mixed":
+        #     dtype = pandas_dtype(inferred_dtype)
+        #     cls = dtype.construct_array_type()
+        #     return cls._from_sequence(data, dtype=dtype, copy=copy)
         else:
             # e.g. complex
             return NumpyExtensionArray._from_sequence(data, dtype=data.dtype, copy=copy)
diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py
index d8a42d83b6c54..85db5817190bb 100644
--- a/pandas/core/dtypes/base.py
+++ b/pandas/core/dtypes/base.py
@@ -444,6 +444,25 @@ def _can_fast_transpose(self) -> bool:
         """
         return False
 
+    @classmethod
+    def is_unambiguous_scalar(cls, scalar) -> bool:
+        """
+        Return whether ``scalar`` unambiguously identifies this dtype.
+
+        The base implementation always answers ``False``.
+        ``Registry.match_scalar`` calls this with the candidate scalar.
+        """
+        return False
+
+    @classmethod
+    def construct_from_scalar(cls, scalar):
+        """
+        Build a dtype instance for ``scalar``.
+
+        The base implementation ignores ``scalar`` and returns ``cls()``.
+        """
+        return cls()
+
 
 class StorageExtensionDtype(ExtensionDtype):
     """ExtensionDtype that may be backed by more than one implementation."""
@@ -582,5 +601,17 @@ def find(
 
         return None
 
+    def match_scalar(
+        self, scalar: Any
+    ) -> type_t[ExtensionDtype] | ExtensionDtype | None:
+        """
+        Return a dtype for ``scalar`` from the first registered dtype that
+        claims it unambiguously, or ``None`` when no registered dtype does.
+        """
+        for dtype in self.dtypes:
+            if dtype.is_unambiguous_scalar(scalar):
+                return dtype.construct_from_scalar(scalar)
+        return None
+
 
 _registry = Registry()
diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py
index 6ba07b1761557..2f926b64ed462 100644
--- a/pandas/core/dtypes/cast.py
+++ b/pandas/core/dtypes/cast.py
@@ -44,6 +44,7 @@
     LossySetitemError,
 )
 
+from pandas.core.dtypes.base import _registry
 from pandas.core.dtypes.common import (
     ensure_int8,
     ensure_int16,
@@ -857,6 +858,12 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
         subtype = infer_dtype_from_scalar(val.left)[0]
         dtype = IntervalDtype(subtype=subtype, closed=val.closed)
 
+    # A registered extension dtype that claims ``val`` takes precedence over
+    # the dtype chosen above.
+    reg_dtype = _registry.match_scalar(val)
+    if reg_dtype is not None:
+        dtype = reg_dtype
+
     return dtype, val
 
 
@@ -913,6 +920,15 @@ def infer_dtype_from_array(arr) -> tuple[DtypeObj, ArrayLike]:
     inferred = lib.infer_dtype(arr, skipna=False)
     if inferred in ["string", "bytes", "mixed", "mixed-integer"]:
         return (np.dtype(np.object_), arr)
+    else:
+        # ``inferred`` is a label from ``lib.infer_dtype``, not necessarily a
+        # dtype name; skip the extension-dtype shortcut when it does not parse.
+        try:
+            arr_dtype = pandas_dtype_func(inferred)
+        except TypeError:
+            arr_dtype = None
+        if isinstance(arr_dtype, ExtensionDtype):
+            return arr_dtype, arr
 
     arr = np.asarray(arr)
     return arr.dtype, arr
diff --git a/pandas/core/series.py b/pandas/core/series.py
index 4f79e30f48f3c..ca64e986f702d 100644
--- a/pandas/core/series.py
+++ b/pandas/core/series.py
@@ -501,6 +501,8 @@ def __init__(
             elif copy:
                 data = data.copy()
         else:
+            if dtype is None:
+                dtype = infer_dtype_from(data)[0]
             data = sanitize_array(data, index, dtype, copy)
             data = SingleBlockManager.from_array(data, index, refs=refs)
 