Skip to content

Commit b00ce10

Browse files
committed
infer
1 parent 078e11f commit b00ce10

File tree

5 files changed

+37
-0
lines changed

5 files changed

+37
-0
lines changed

pandas/_libs/lib.pyx

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,8 @@ from pandas._libs.tslibs.period cimport is_period_object
107107
from pandas._libs.tslibs.timedeltas cimport convert_to_timedelta64
108108
from pandas._libs.tslibs.timezones cimport tz_compare
109109

110+
from pandas.core.dtypes.base import _registry
111+
110112
# constants that will be compared to potentially arbitrarily large
111113
# python int
112114
cdef:
@@ -1693,6 +1695,11 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
16931695
if is_interval_array(values):
16941696
return "interval"
16951697

1698+
print("infer_dtype")
1699+
reg_dtype = _registry.match_scalar(val)
1700+
if reg_dtype:
1701+
return str(reg_dtype)
1702+
16961703
cnp.PyArray_ITER_RESET(it)
16971704
for i in range(n):
16981705
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))

pandas/core/construction.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -370,6 +370,10 @@ def array(
370370

371371
elif data.dtype.kind == "b":
372372
return BooleanArray._from_sequence(data, dtype="boolean", copy=copy)
373+
# elif inferred_dtype != "mixed":
374+
# dtype = pandas_dtype(inferred_dtype)
375+
# cls = dtype.construct_array_type()
376+
# return cls._from_sequence(data, dtype=dtype, copy=copy)
373377
else:
374378
# e.g. complex
375379
return NumpyExtensionArray._from_sequence(data, dtype=data.dtype, copy=copy)

pandas/core/dtypes/base.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -444,6 +444,13 @@ def _can_fast_transpose(self) -> bool:
444444
"""
445445
return False
446446

447+
def is_unambiguous_scalar(self):
448+
return False
449+
450+
@classmethod
451+
def construct_from_scalar(cls, scalar):
452+
return cls()
453+
447454

448455
class StorageExtensionDtype(ExtensionDtype):
449456
"""ExtensionDtype that may be backed by more than one implementation."""
@@ -582,5 +589,13 @@ def find(
582589

583590
return None
584591

592+
def match_scalar(
593+
self, scalar: Any
594+
) -> type_t[ExtensionDtype] | ExtensionDtype | None:
595+
for dtype in self.dtypes:
596+
if dtype.is_unambiguous_scalar(scalar):
597+
return dtype.construct_from_scalar(scalar)
598+
return None
599+
585600

586601
_registry = Registry()

pandas/core/dtypes/cast.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@
4444
LossySetitemError,
4545
)
4646

47+
from pandas.core.dtypes.base import _registry
4748
from pandas.core.dtypes.common import (
4849
ensure_int8,
4950
ensure_int16,
@@ -857,6 +858,10 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
857858
subtype = infer_dtype_from_scalar(val.left)[0]
858859
dtype = IntervalDtype(subtype=subtype, closed=val.closed)
859860

861+
reg_dtype = _registry.match_scalar(val)
862+
if reg_dtype:
863+
dtype = reg_dtype
864+
860865
return dtype, val
861866

862867

@@ -913,6 +918,10 @@ def infer_dtype_from_array(arr) -> tuple[DtypeObj, ArrayLike]:
913918
inferred = lib.infer_dtype(arr, skipna=False)
914919
if inferred in ["string", "bytes", "mixed", "mixed-integer"]:
915920
return (np.dtype(np.object_), arr)
921+
else:
922+
arr_dtype = pandas_dtype_func(inferred)
923+
if isinstance(arr_dtype, ExtensionDtype):
924+
return arr_dtype, arr
916925

917926
arr = np.asarray(arr)
918927
return arr.dtype, arr

pandas/core/series.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -501,6 +501,8 @@ def __init__(
501501
elif copy:
502502
data = data.copy()
503503
else:
504+
if dtype is None:
505+
dtype = infer_dtype_from(data)[0]
504506
data = sanitize_array(data, index, dtype, copy)
505507
data = SingleBlockManager.from_array(data, index, refs=refs)
506508

0 commit comments

Comments
 (0)