|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import TYPE_CHECKING, Any, Sequence, TypeVar |
| 3 | +from typing import TYPE_CHECKING, Any, Sequence, TypeVar, Union, overload |
| 4 | +from typing import cast as type_cast |
4 | 5 |
|
5 | | -import numpy as np # new |
| 6 | +import numpy as np |
6 | 7 |
|
7 | 8 | from numpy import ndarray |
| 9 | +from numpy.typing import NDArray |
| 10 | +from pandas import CategoricalDtype as pd_CategoricalDtype |
| 11 | +from pandas import StringDtype as pd_StringDtype |
8 | 12 | from pandas.api.extensions import ExtensionArray |
| 13 | +from pandas.core.dtypes.dtypes import ExtensionDtype |
9 | 14 |
|
10 | 15 | import arkouda as ak |
11 | 16 |
|
12 | 17 | from ._arkouda_array import ArkoudaArray |
13 | 18 | from ._arkouda_extension_array import ArkoudaExtensionArray |
| 19 | +from ._arkouda_string_array import ArkoudaStringArray |
14 | 20 | from ._dtypes import ArkoudaCategoricalDtype |
15 | 21 |
|
16 | 22 |
|
@@ -84,8 +90,125 @@ def __getitem__(self, idx): |
84 | 90 | return self._data[idx] |
85 | 91 | return ArkoudaCategoricalArray(self._data[idx]) |
86 | 92 |
|
@overload
def astype(self, dtype: np.dtype[Any], copy: bool = True) -> NDArray[Any]: ...

@overload
def astype(self, dtype: ExtensionDtype, copy: bool = True) -> ExtensionArray: ...

@overload
def astype(self, dtype: Any, copy: bool = True) -> Union[ExtensionArray, NDArray[Any]]: ...

def astype(
    self,
    dtype: Any,
    copy: bool = True,
) -> Union[ExtensionArray, NDArray[Any]]:
    """
    Cast to a specified dtype.

    The target is classified on ``dtype`` exactly as passed, in this order:

    * If ``dtype`` is not a pandas ``ExtensionDtype`` and requests ``object``
      (``object``, ``np.object_``, ``"object"`` or ``np.dtype("O")``), returns a
      NumPy ``ndarray`` of dtype object built from ``self.to_ndarray()``.
    * If ``dtype`` is categorical (``"category"`` / pandas ``CategoricalDtype`` /
      ``ArkoudaCategoricalDtype``), returns ``self`` when ``copy=False``, otherwise
      a new array of the same class as ``self``.
    * If ``dtype`` is a pandas ``StringDtype`` or is accepted by
      ``is_string_dtype_hint``, returns an ``ArkoudaStringArray`` built from the
      labels as strings.
    * Otherwise, casts the labels (as strings) to the requested dtype and wraps
      the result with ``ArkoudaExtensionArray._from_sequence``.

    For an ``ExtensionDtype`` that exposes ``numpy_dtype`` and is neither
    categorical nor a pandas ``StringDtype``, the last two checks are made against
    ``dtype.numpy_dtype``.

    Parameters
    ----------
    dtype : Any
        Target dtype.
    copy : bool
        Only consulted for categorical targets. With ``copy=False`` the array
        itself is returned. With ``copy=True`` the underlying data is copied via
        its ``copy()`` method when it has one; when it does not, the new array
        wraps the same underlying data. Default is True.

    Returns
    -------
    Union[ExtensionArray, NDArray[Any]]
        The cast result. A NumPy array is returned only for a non-extension
        ``object`` request; every other branch returns an ExtensionArray.

    Examples
    --------
    Casting to ``category`` returns an Arkouda-backed categorical array:

    >>> import arkouda as ak
    >>> from arkouda.pandas.extension import ArkoudaCategoricalArray
    >>> c = ArkoudaCategoricalArray(ak.Categorical(ak.array(["x", "y", "x"])))
    >>> out = c.astype("category")
    >>> out is c
    False

    Forcing a copy when casting to the same categorical dtype returns a new array:

    >>> out2 = c.astype("category", copy=True)
    >>> out2 is c
    False
    >>> out2.to_ndarray()
    array(['x', 'y', 'x'], dtype='<U...')

    Casting to ``object`` materializes the category labels to a NumPy object array:

    >>> c.astype(object)
    array(['x', 'y', 'x'], dtype=object)

    Casting to a string dtype returns an Arkouda-backed string array of labels:

    >>> s = c.astype("string")
    >>> s.to_ndarray()
    array(['x', 'y', 'x'], dtype='<U1')

    Casting to another dtype casts the labels-as-strings and returns an Arkouda-backed array:

    >>> c_num = ArkoudaCategoricalArray(ak.Categorical(ak.array(["1", "2", "3"])))
    >>> a = c_num.astype("int64")
    >>> a.to_ndarray()
    array([1, 2, 3])
    """
    from arkouda.numpy._typing._typing import is_string_dtype_hint

    is_extension_dtype = isinstance(dtype, ExtensionDtype)

    # object -> NumPy. Restricted to non-extension dtypes so that an
    # ExtensionDtype request always yields an ExtensionArray (overload #2).
    if not is_extension_dtype and dtype in (object, np.object_, "object", np.dtype("O")):
        return self.to_ndarray().astype(object, copy=False)

    # Categorical targets are recognised on the dtype as passed, i.e. before any
    # ``numpy_dtype`` unwrapping below could replace it with a plain NumPy dtype.
    if isinstance(dtype, (ArkoudaCategoricalDtype, pd_CategoricalDtype)) or dtype in (
        "category",
    ):
        if not copy:
            return self
        # Best effort: copy only when the underlying data offers ``copy()``.
        data = self._data.copy() if hasattr(self._data, "copy") else self._data
        return type_cast(ExtensionArray, type(self)(data))

    # Same reasoning for pandas StringDtype: decide before unwrapping.
    wants_pandas_string = isinstance(dtype, pd_StringDtype)

    # Remaining extension dtypes are reduced to their NumPy equivalent when they
    # advertise one, so the label cast below receives a NumPy dtype.
    if is_extension_dtype and not wants_pandas_string and hasattr(dtype, "numpy_dtype"):
        dtype = dtype.numpy_dtype

    labels = self._data.to_strings()

    if wants_pandas_string or is_string_dtype_hint(dtype):
        return type_cast(ExtensionArray, ArkoudaStringArray._from_sequence(labels))

    casted = labels.astype(dtype)
    return type_cast(ExtensionArray, ArkoudaExtensionArray._from_sequence(casted))
89 | 212 |
|
def isna(self):
    """
    Return a missing-value mask in which no position is flagged.

    Returns
    -------
    The result of ``ak.zeros`` with dtype ``ak.bool`` and one entry per element
    of the underlying data (``self._data.size``); every entry is zero/False.
    """
    length = self._data.size
    return ak.zeros(length, dtype=ak.bool)
|
0 commit comments