Skip to content

Commit 28111d6

Browse files
committed
Closes #5219: improve astype in ak.pandas.extension
1 parent 03aa4df commit 28111d6

File tree

9 files changed

+615
-58
lines changed

9 files changed

+615
-58
lines changed

arkouda/numpy/_typing/_typing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
None,
2929
]
3030

31-
StringDTypeTypes: TypeAlias = _Union[Literal["str", "str_"], type[str_], type[str], type[Strings]]
31+
StringDTypeTypes: TypeAlias = _Union[
32+
Literal["str", "str_", "string"], type[str_], type[str], type[Strings]
33+
]
3234

3335
_ArrayLikeNum: TypeAlias = _Union[
3436
np.ndarray, # keeps it simple; or list your NDArray[...]
@@ -88,4 +90,9 @@
8890

8991
def is_string_dtype_hint(x: object) -> TypeGuard["_StringDType"]:
9092
# accept the spellings you want to map to Arkouda Strings
91-
return x in ("str", "str_") or x is str_ or x is str_ or x is Strings
93+
return (
94+
(isinstance(x, str) and x.lower() in ("str", "str_", "string", "strings"))
95+
or x is str_
96+
or x is str
97+
or x is Strings
98+
)

arkouda/numpy/dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def dtype(x):
258258
return bigint()
259259

260260
# ---- String dtype spellings ----
261-
if isinstance(x, str) and x.lower() in {"str", "str_", "Strings", "strings"}:
261+
if isinstance(x, str) and x.lower() in {"str", "str_", "strings", "string"}:
262262
return np.dtype(np.str_)
263263
if x in (str, np.str_):
264264
return np.dtype(np.str_)

arkouda/pandas/extension/_arkouda_array.py

Lines changed: 100 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Sequence, TypeVar
3+
from typing import TYPE_CHECKING, Any, Sequence, TypeVar, Union, overload
44
from typing import cast as type_cast
55

66
import numpy as np
77

88
from numpy import ndarray
99
from numpy.typing import NDArray
1010
from pandas.api.extensions import ExtensionArray
11-
12-
from arkouda.numpy.dtypes import dtype as ak_dtype
11+
from pandas.core.dtypes.dtypes import ExtensionDtype
1312

1413
from ._arkouda_extension_array import ArkoudaExtensionArray
1514
from ._dtypes import (
@@ -167,27 +166,109 @@ def __setitem__(self, key, value):
167166

168167
self._data[key] = value
169168

170-
def astype(self, dtype, copy: bool = False):
171-
# Always hand back a real object-dtype ndarray when object is requested
172-
if dtype in (object, np.object_, "object", np.dtype("O")):
173-
return self.to_ndarray().astype(object, copy=copy)
169+
@overload
170+
def astype(self, dtype: np.dtype[Any], copy: bool = True) -> NDArray[Any]: ...
174171

175-
if isinstance(dtype, _ArkoudaBaseDtype):
176-
dtype = dtype.numpy_dtype
172+
@overload
173+
def astype(self, dtype: ExtensionDtype, copy: bool = True) -> ExtensionArray: ...
177174

178-
# Server-side cast for numeric/bool
179-
try:
180-
npdt = np.dtype(dtype)
181-
except Exception:
182-
return self.to_ndarray().astype(dtype, copy=copy)
175+
@overload
176+
def astype(self, dtype: Any, copy: bool = True) -> Union[ExtensionArray, NDArray[Any]]: ...
177+
178+
def astype(
179+
self,
180+
dtype: Any,
181+
copy: bool = True,
182+
) -> Union[ExtensionArray, NDArray[Any]]:
183+
"""
184+
Cast the array to a specified dtype.
185+
186+
Casting rules:
187+
188+
* If ``dtype`` requests ``object``, returns a NumPy ``NDArray[Any]`` of
189+
dtype ``object`` containing the array values.
190+
* Otherwise, the target dtype is normalized using Arkouda's dtype
191+
resolution rules.
192+
* If the normalized dtype matches the current dtype and ``copy=False``,
193+
returns ``self``.
194+
* In all other cases, casts the underlying Arkouda array to the target
195+
dtype and returns an Arkouda-backed ``ArkoudaExtensionArray``.
196+
197+
Parameters
198+
----------
199+
dtype : Any
200+
Target dtype. May be a NumPy dtype, pandas dtype, Arkouda dtype,
201+
or any dtype-like object accepted by Arkouda.
202+
copy : bool
203+
Whether to force a copy when the target dtype matches the current dtype.
204+
Default is True.
205+
206+
Returns
207+
-------
208+
Union[ExtensionArray, NDArray[Any]]
209+
The cast result. Returns a NumPy array only when casting to ``object``;
210+
otherwise returns an Arkouda-backed ExtensionArray.
211+
212+
Examples
213+
--------
214+
Basic numeric casting returns an Arkouda-backed array:
215+
216+
>>> import arkouda as ak
217+
>>> from arkouda.pandas.extension import ArkoudaArray
218+
>>> a = ArkoudaArray(ak.array([1, 2, 3], dtype="int64"))
219+
>>> a.astype("float64").to_ndarray()
220+
array([1., 2., 3.])
221+
222+
Casting to the same dtype with ``copy=False`` returns the original object:
223+
224+
>>> b = a.astype("int64", copy=False)
225+
>>> b is a
226+
True
227+
228+
Forcing a copy when the dtype is unchanged returns a new array:
229+
230+
>>> c = a.astype("int64", copy=True)
231+
>>> c is a
232+
False
233+
>>> c.to_ndarray()
234+
array([1, 2, 3])
235+
236+
Casting to ``object`` materializes the data to a NumPy array:
237+
238+
>>> a.astype(object)
239+
array([1, 2, 3], dtype=object)
240+
241+
NumPy and pandas dtype objects are also accepted:
242+
243+
>>> import numpy as np
244+
>>> a.astype(np.dtype("bool")).to_ndarray()
245+
array([ True, True, True])
246+
"""
247+
from arkouda.numpy.dtypes import dtype as ak_dtype
248+
249+
# --- 1) ExtensionDtype branch (satisfies overload #2) ---
250+
if isinstance(dtype, ExtensionDtype):
251+
# pandas extension dtypes typically have .numpy_dtype
252+
if hasattr(dtype, "numpy_dtype"):
253+
dtype = dtype.numpy_dtype
254+
255+
if copy is False and self.dtype.numpy_dtype == dtype:
256+
return self
257+
258+
casted = self._data.astype(dtype)
259+
return type_cast(ExtensionArray, ArkoudaExtensionArray._from_sequence(casted))
260+
261+
# --- 2) object -> numpy (the only branch that returns an ndarray) ---
262+
if dtype in (object, np.object_, "object", np.dtype("O")):
263+
return self.to_ndarray().astype(object, copy=False)
183264

184-
from arkouda.numpy.numeric import cast as ak_cast
265+
dtype = ak_dtype(dtype)
185266

186-
if npdt.kind in {"i", "u", "f", "b"}:
187-
return type(self)(ak_cast(self._data, ak_dtype(npdt.name)))
267+
if copy is False and self.dtype.numpy_dtype == dtype:
268+
return self
188269

189-
# Fallback: local cast
190-
return self.to_ndarray().astype(npdt, copy=copy)
270+
casted = self._data.astype(dtype)
271+
return ArkoudaExtensionArray._from_sequence(casted)
191272

192273
def isna(self) -> NDArray[np.bool_]:
193274
from arkouda.numpy import isnan

arkouda/pandas/extension/_arkouda_categorical_array.py

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Sequence, TypeVar
3+
from typing import TYPE_CHECKING, Any, Sequence, TypeVar, Union, overload
4+
from typing import cast as type_cast
45

5-
import numpy as np # new
6+
import numpy as np
67

78
from numpy import ndarray
9+
from numpy.typing import NDArray
10+
from pandas import CategoricalDtype as pd_CategoricalDtype
11+
from pandas import StringDtype as pd_StringDtype
812
from pandas.api.extensions import ExtensionArray
13+
from pandas.core.dtypes.dtypes import ExtensionDtype
914

1015
import arkouda as ak
1116

1217
from ._arkouda_array import ArkoudaArray
1318
from ._arkouda_extension_array import ArkoudaExtensionArray
19+
from ._arkouda_string_array import ArkoudaStringArray
1420
from ._dtypes import ArkoudaCategoricalDtype
1521

1622

@@ -84,8 +90,125 @@ def __getitem__(self, idx):
8490
return self._data[idx]
8591
return ArkoudaCategoricalArray(self._data[idx])
8692

87-
def astype(self, x, dtype):
88-
raise NotImplementedError("array_api.astype is not implemented in Arkouda yet")
93+
@overload
94+
def astype(self, dtype: np.dtype[Any], copy: bool = True) -> NDArray[Any]: ...
95+
96+
@overload
97+
def astype(self, dtype: ExtensionDtype, copy: bool = True) -> ExtensionArray: ...
98+
99+
@overload
100+
def astype(self, dtype: Any, copy: bool = True) -> Union[ExtensionArray, NDArray[Any]]: ...
101+
102+
def astype(
103+
self,
104+
dtype: Any,
105+
copy: bool = True,
106+
) -> Union[ExtensionArray, NDArray[Any]]:
107+
"""
108+
Cast to a specified dtype.
109+
110+
* If ``dtype`` is categorical (pandas ``category`` / ``CategoricalDtype`` /
111+
``ArkoudaCategoricalDtype``), returns an Arkouda-backed
112+
``ArkoudaCategoricalArray`` (optionally copied).
113+
* If ``dtype`` requests ``object``, returns a NumPy ``ndarray`` of dtype object
114+
containing the category labels (materialized to the client).
115+
* If ``dtype`` requests a string dtype, returns an Arkouda-backed
116+
``ArkoudaStringArray`` containing the labels as strings.
117+
* Otherwise, casts the labels (as strings) to the requested dtype and returns an
118+
Arkouda-backed ExtensionArray.
119+
120+
Parameters
121+
----------
122+
dtype : Any
123+
Target dtype.
124+
copy : bool
125+
Whether to force a copy when possible. If categorical-to-categorical and
126+
``copy=True``, attempts to copy the underlying Arkouda ``Categorical`` (if
127+
supported). Default is True.
128+
129+
Returns
130+
-------
131+
Union[ExtensionArray, NDArray[Any]]
132+
The cast result. Returns a NumPy array only when casting to ``object``;
133+
otherwise returns an Arkouda-backed ExtensionArray.
134+
135+
Examples
136+
--------
137+
Casting to ``category`` returns an Arkouda-backed categorical array (a new object, since ``copy`` defaults to True):
138+
139+
>>> import arkouda as ak
140+
>>> from arkouda.pandas.extension import ArkoudaCategoricalArray
141+
>>> c = ArkoudaCategoricalArray(ak.Categorical(ak.array(["x", "y", "x"])))
142+
>>> out = c.astype("category")
143+
>>> out is c
144+
False
145+
146+
Forcing a copy when casting to the same categorical dtype returns a new array:
147+
148+
>>> out2 = c.astype("category", copy=True)
149+
>>> out2 is c
150+
False
151+
>>> out2.to_ndarray()
152+
array(['x', 'y', 'x'], dtype='<U...')
153+
154+
Casting to ``object`` materializes the category labels to a NumPy object array:
155+
156+
>>> c.astype(object)
157+
array(['x', 'y', 'x'], dtype=object)
158+
159+
Casting to a string dtype returns an Arkouda-backed string array of labels:
160+
161+
>>> s = c.astype("string")
162+
>>> s.to_ndarray()
163+
array(['x', 'y', 'x'], dtype='<U1')
164+
165+
Casting to another dtype casts the labels-as-strings and returns an Arkouda-backed array:
166+
167+
>>> c_num = ArkoudaCategoricalArray(ak.Categorical(ak.array(["1", "2", "3"])))
168+
>>> a = c_num.astype("int64")
169+
>>> a.to_ndarray()
170+
array([1, 2, 3])
171+
"""
172+
from arkouda.numpy._typing._typing import is_string_dtype_hint
173+
174+
# --- 1) ExtensionDtype branch first: proves overload #2 returns ExtensionArray ---
175+
if isinstance(dtype, ExtensionDtype):
176+
if hasattr(dtype, "numpy_dtype"):
177+
dtype = dtype.numpy_dtype
178+
179+
if isinstance(dtype, (ArkoudaCategoricalDtype, pd_CategoricalDtype)) or dtype in (
180+
"category",
181+
):
182+
if not copy:
183+
return self
184+
data = self._data.copy() if hasattr(self._data, "copy") else self._data
185+
return type_cast(ExtensionArray, type(self)(data))
186+
187+
data = self._data.to_strings()
188+
189+
if isinstance(dtype, pd_StringDtype) or is_string_dtype_hint(dtype):
190+
return type_cast(ExtensionArray, ArkoudaStringArray._from_sequence(data))
191+
192+
casted = data.astype(dtype)
193+
return type_cast(ExtensionArray, ArkoudaExtensionArray._from_sequence(casted))
194+
195+
# --- 2) object -> numpy ---
196+
if dtype in (object, np.object_, "object", np.dtype("O")):
197+
return self.to_ndarray().astype(object, copy=False)
198+
199+
if isinstance(dtype, (ArkoudaCategoricalDtype, pd_CategoricalDtype)) or dtype in ("category",):
200+
if not copy:
201+
return self
202+
data = self._data.copy() if hasattr(self._data, "copy") else self._data
203+
return type(self)(data)
204+
205+
data = self._data.to_strings()
206+
207+
if isinstance(dtype, pd_StringDtype) or is_string_dtype_hint(dtype):
208+
return ArkoudaStringArray._from_sequence(data)
209+
210+
casted = data.astype(dtype)
211+
return ArkoudaExtensionArray._from_sequence(casted)
89212

90213
def isna(self):
91214
return ak.zeros(self._data.size, dtype=ak.bool)

arkouda/pandas/extension/_arkouda_extension_array.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,7 @@ def _from_sequence(
332332
from arkouda.numpy.pdarraycreation import array as ak_array
333333
from arkouda.numpy.strings import Strings
334334
from arkouda.pandas.categorical import Categorical
335-
from arkouda.pandas.extension._arkouda_array import ArkoudaArray
336-
from arkouda.pandas.extension._arkouda_categorical_array import ArkoudaCategoricalArray
337-
from arkouda.pandas.extension._arkouda_string_array import ArkoudaStringArray
335+
from arkouda.pandas.extension import ArkoudaArray, ArkoudaCategoricalArray, ArkoudaStringArray
338336

339337
# Fast path: already an Arkouda column. Pick the matching subclass.
340338
if isinstance(scalars, pdarray):

0 commit comments

Comments
 (0)