diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index 3bef155f377b0..55eddb8045ca6 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -34,6 +34,9 @@ class ArrowStringArrayMixin: def __init__(self, *args, **kwargs) -> None: raise NotImplementedError + def _from_pyarrow_array(self, pa_array) -> Self: + raise NotImplementedError + def _convert_bool_result(self, result, na=lib.no_default, method_name=None): # Convert a bool-dtype result to the appropriate result type raise NotImplementedError @@ -50,31 +53,31 @@ def _str_len(self): return self._convert_int_result(result) def _str_lower(self) -> Self: - return type(self)(pc.utf8_lower(self._pa_array)) + return self._from_pyarrow_array(pc.utf8_lower(self._pa_array)) def _str_upper(self) -> Self: - return type(self)(pc.utf8_upper(self._pa_array)) + return self._from_pyarrow_array(pc.utf8_upper(self._pa_array)) def _str_strip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_trim_whitespace(self._pa_array) else: result = pc.utf8_trim(self._pa_array, characters=to_strip) - return type(self)(result) + return self._from_pyarrow_array(result) def _str_lstrip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_ltrim_whitespace(self._pa_array) else: result = pc.utf8_ltrim(self._pa_array, characters=to_strip) - return type(self)(result) + return self._from_pyarrow_array(result) def _str_rstrip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_rtrim_whitespace(self._pa_array) else: result = pc.utf8_rtrim(self._pa_array, characters=to_strip) - return type(self)(result) + return self._from_pyarrow_array(result) def _str_pad( self, @@ -104,7 +107,9 @@ def _str_pad( raise ValueError( f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'" ) - return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar)) + return self._from_pyarrow_array( + pa_pad(self._pa_array, width=width, padding=fillchar) + ) def _str_get(self, i: int) -> Self: lengths = pc.utf8_length(self._pa_array) @@ -124,7 +129,7 @@ def _str_get(self, i: int) -> Self: ) null_value = pa.scalar(None, type=self._pa_array.type) result = pc.if_else(not_out_of_bounds, selected, null_value) - return type(self)(result) + return self._from_pyarrow_array(result) def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None @@ -132,7 +137,9 @@ def _str_slice( if pa_version_under13p0: # GH#59724 result = self._apply_elementwise(lambda val: val[start:stop:step]) - return type(self)(pa.chunked_array(result, type=self._pa_array.type)) + return self._from_pyarrow_array( + pa.chunked_array(result, type=self._pa_array.type) + ) if start is None: if step is not None and step < 0: # GH#59710 @@ -141,7 +148,7 @@ def _str_slice( start = 0 if step is None: step = 1 - return type(self)( + return self._from_pyarrow_array( pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step) ) @@ -154,7 +161,9 @@ def _str_slice_replace( start = 0 if stop is None: stop = np.iinfo(np.int64).max - return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl)) + return self._from_pyarrow_array( + pc.utf8_replace_slice(self._pa_array, start, stop, repl) + ) def _str_replace( self, @@ -181,32 +190,32 @@ def _str_replace( replacement=repl, max_replacements=pa_max_replacements, ) - return type(self)(result) + return self._from_pyarrow_array(result) def _str_capitalize(self) -> Self: - return type(self)(pc.utf8_capitalize(self._pa_array)) + return self._from_pyarrow_array(pc.utf8_capitalize(self._pa_array)) def _str_title(self) -> Self: - return type(self)(pc.utf8_title(self._pa_array)) + return self._from_pyarrow_array(pc.utf8_title(self._pa_array)) def _str_swapcase(self) -> Self: - return type(self)(pc.utf8_swapcase(self._pa_array)) + return self._from_pyarrow_array(pc.utf8_swapcase(self._pa_array)) def _str_removeprefix(self, prefix: str): if not pa_version_under13p0: starts_with = pc.starts_with(self._pa_array, pattern=prefix) removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix)) result = pc.if_else(starts_with, removed, self._pa_array) - return type(self)(result) + return self._from_pyarrow_array(result) predicate = lambda val: val.removeprefix(prefix) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_removesuffix(self, suffix: str): ends_with = pc.ends_with(self._pa_array, pattern=suffix) removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix)) result = pc.if_else(ends_with, removed, self._pa_array) - return type(self)(result) + return self._from_pyarrow_array(result) def _str_startswith( self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index d9890fb331cfa..221191773186e 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -390,6 +390,13 @@ def _from_sequence_of_strings( ) return cls._from_sequence(scalars, dtype=pa_type, copy=copy) + def _from_pyarrow_array(self, pa_array): + """ + Construct from the pyarrow array result of an operation, for + compatibility with ArrowStringArray. + """ + return type(self)(pa_array) + def _cast_pointwise_result(self, values) -> ArrayLike: if len(values) == 0: # Retain our dtype @@ -448,14 +455,14 @@ def _cast_pointwise_result(self, values) -> ArrayLike: if isinstance(self.dtype, StringDtype): if pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type): - # ArrowStringArrayNumpySemantics - return type(self)(arr).astype(self.dtype) + # ArrowStringArray preserves dtype.na_value + return self._from_pyarrow_array(arr) if self.dtype.na_value is np.nan: # ArrowEA has different semantics, so we return numpy-based # result instead return super()._cast_pointwise_result(values) return ArrowExtensionArray(arr) - return type(self)(arr) + return self._from_pyarrow_array(arr) @classmethod def _box_pa( @@ -677,11 +684,13 @@ def __getitem__(self, item: PositionalIndexer): pa_dtype = pa.string() else: pa_dtype = self._dtype.pyarrow_dtype - return type(self)(pa.chunked_array([], type=pa_dtype)) + result = pa.chunked_array([], type=pa_dtype) + return self._from_pyarrow_array(result) + elif item.dtype.kind in "iu": return self.take(item) elif item.dtype.kind == "b": - return type(self)(self._pa_array.filter(item)) + return self._from_pyarrow_array(self._pa_array.filter(item)) else: raise IndexError( "Only integers, slices and integer or " @@ -717,7 +726,7 @@ def __getitem__(self, item: PositionalIndexer): value = self._pa_array[item] if isinstance(value, pa.ChunkedArray): - return type(self)(value) + return self._from_pyarrow_array(value) else: pa_type = self._pa_array.type scalar = value.as_py() @@ -774,28 +783,28 @@ def __array__( def __invert__(self) -> Self: # This is a bit wise op for integer types if pa.types.is_integer(self._pa_array.type): - return type(self)(pc.bit_wise_not(self._pa_array)) + return self._from_pyarrow_array(pc.bit_wise_not(self._pa_array)) elif pa.types.is_string(self._pa_array.type) or pa.types.is_large_string( self._pa_array.type ): # Raise TypeError instead of pa.ArrowNotImplementedError raise TypeError("__invert__ is not supported for string dtypes") else: - return type(self)(pc.invert(self._pa_array)) + return self._from_pyarrow_array(pc.invert(self._pa_array)) def __neg__(self) -> Self: try: - return type(self)(pc.negate_checked(self._pa_array)) + return self._from_pyarrow_array(pc.negate_checked(self._pa_array)) except pa.ArrowNotImplementedError as err: raise TypeError( f"unary '-' not supported for dtype '{self.dtype}'" ) from err def __pos__(self) -> Self: - return type(self)(self._pa_array) + return self._from_pyarrow_array(self._pa_array) def __abs__(self) -> Self: - return type(self)(pc.abs_checked(self._pa_array)) + return self._from_pyarrow_array(pc.abs_checked(self._pa_array)) # GH 42600: __getstate__/__setstate__ not necessary once # https://issues.apache.org/jira/browse/ARROW-10739 is addressed @@ -873,7 +882,7 @@ def _evaluate_op_method(self, other, op, arrow_funcs) -> Self: raise TypeError( self._op_method_error_message(other_original, op) ) from err - return type(self)(result) + return self._from_pyarrow_array(result) elif op in [operator.mul, roperator.rmul]: binary = self._pa_array integral = other @@ -881,7 +890,7 @@ def _evaluate_op_method(self, other, op, arrow_funcs) -> Self: raise TypeError("Can only string multiply by an integer.") pa_integral = pc.if_else(pc.less(integral, 0), 0, integral) result = pc.binary_repeat(binary, pa_integral) - return type(self)(result) + return self._from_pyarrow_array(result) elif ( pa.types.is_string(other.type) or pa.types.is_binary(other.type) @@ -893,7 +902,7 @@ def _evaluate_op_method(self, other, op, arrow_funcs) -> Self: raise TypeError("Can only string multiply by an integer.") pa_integral = pc.if_else(pc.less(integral, 0), 0, integral) result = pc.binary_repeat(binary, pa_integral) - return type(self)(result) + return self._from_pyarrow_array(result) if ( isinstance(other, pa.Scalar) and pc.is_null(other).as_py() @@ -912,7 +921,7 @@ def _evaluate_op_method(self, other, op, arrow_funcs) -> Self: result = pc_func(self._pa_array, other) except pa.ArrowNotImplementedError as err: raise TypeError(self._op_method_error_message(other_original, op)) from err - return type(self)(result) + return self._from_pyarrow_array(result) def _logical_method(self, other, op) -> Self: # For integer types `^`, `|`, `&` are bitwise operators and return @@ -1185,7 +1194,7 @@ def copy(self) -> Self: ------- type(self) """ - return type(self)(self._pa_array) + return self._from_pyarrow_array(self._pa_array) def dropna(self) -> Self: """ @@ -1195,7 +1204,7 @@ def dropna(self) -> Self: ------- ArrowExtensionArray """ - return type(self)(pc.drop_null(self._pa_array)) + return self._from_pyarrow_array(pc.drop_null(self._pa_array)) def _pad_or_backfill( self, @@ -1212,9 +1221,13 @@ def _pad_or_backfill( method = missing.clean_fill_method(method) try: if method == "pad": - return type(self)(pc.fill_null_forward(self._pa_array)) + return self._from_pyarrow_array( + pc.fill_null_forward(self._pa_array) + ) elif method == "backfill": - return type(self)(pc.fill_null_backward(self._pa_array)) + return self._from_pyarrow_array( + pc.fill_null_backward(self._pa_array) + ) except pa.ArrowNotImplementedError: # ArrowNotImplementedError: Function 'coalesce' has no kernel # matching input types (duration[ns], duration[ns]) @@ -1258,7 +1271,9 @@ def fillna( raise TypeError(msg) from err try: - return type(self)(pc.fill_null(self._pa_array, fill_value=fill_value)) + return self._from_pyarrow_array( + pc.fill_null(self._pa_array, fill_value=fill_value) + ) except pa.ArrowNotImplementedError: # ArrowNotImplementedError: Function 'coalesce' has no kernel # matching input types (duration[ns], duration[ns]) @@ -1315,7 +1330,9 @@ def factorize( encoded = data.dictionary_encode(null_encoding=null_encoding) if encoded.length() == 0: indices = np.array([], dtype=np.intp) - uniques = type(self)(pa.chunked_array([], type=encoded.type.value_type)) + uniques = self._from_pyarrow_array( + pa.chunked_array([], type=encoded.type.value_type) + ) else: # GH 54844 combined = encoded.combine_chunks() @@ -1325,7 +1342,7 @@ def factorize( indices = pa_indices.to_numpy(zero_copy_only=False, writable=True).astype( np.intp, copy=False ) - uniques = type(self)(combined.dictionary) + uniques = self._from_pyarrow_array(combined.dictionary) return indices, uniques @@ -1357,7 +1374,7 @@ def round(self, decimals: int = 0, *args, **kwargs) -> Self: DataFrame.round : Round values of a DataFrame. Series.round : Round values of a Series. """ - return type(self)(pc.round(self._pa_array, ndigits=decimals)) + return self._from_pyarrow_array(pc.round(self._pa_array, ndigits=decimals)) @doc(ExtensionArray.searchsorted) def searchsorted( @@ -1459,23 +1476,23 @@ def take( indices_array = pa.array(indices_array, mask=fill_mask) result = self._pa_array.take(indices_array) if isna(fill_value): - return type(self)(result) + return self._from_pyarrow_array(result) # TODO: ArrowNotImplementedError: Function fill_null has no # kernel matching input types (array[string], scalar[string]) - result = type(self)(result) + result = self._from_pyarrow_array(result) result[fill_mask] = fill_value return result # return type(self)(pc.fill_null(result, pa.scalar(fill_value))) else: # Nothing to fill - return type(self)(self._pa_array.take(indices)) + return self._from_pyarrow_array(self._pa_array.take(indices)) else: # allow_fill=False # TODO(ARROW-9432): Treat negative indices as indices from the right. if (indices_array < 0).any(): # Don't modify in-place indices_array = np.copy(indices_array) indices_array[indices_array < 0] += len(self._pa_array) - return type(self)(self._pa_array.take(indices_array)) + return self._from_pyarrow_array(self._pa_array.take(indices_array)) def _maybe_convert_datelike_array(self): """Maybe convert to a datelike array.""" @@ -1614,7 +1631,7 @@ def unique(self) -> Self: ArrowExtensionArray """ pa_result = pc.unique(self._pa_array) - return type(self)(pa_result) + return self._from_pyarrow_array(pa_result) def value_counts(self, dropna: bool = True) -> Series: """ @@ -1650,7 +1667,7 @@ def value_counts(self, dropna: bool = True) -> Series: counts = ArrowExtensionArray(counts) - index = Index(type(self)(values)) + index = Index(self._from_pyarrow_array(values)) return Series(counts, index=index, name="count", copy=False) @@ -1674,7 +1691,7 @@ def _concat_same_type(cls, to_concat) -> Self: else: pa_dtype = to_concat[0].dtype.pyarrow_dtype arr = pa.chunked_array(chunks, type=pa_dtype) - return cls(arr) + return to_concat[0]._from_pyarrow_array(arr) def _accumulate( self, name: str, *, skipna: bool = True, **kwargs @@ -1742,7 +1759,7 @@ def _accumulate( if convert_to_int: result = result.cast(pa_dtype) - return type(self)(result) + return self._from_pyarrow_array(result) def _str_accumulate( self, name: str, *, skipna: bool = True, **kwargs @@ -1769,7 +1786,7 @@ def _str_accumulate( if self._hasna: na_mask = pc.is_null(pa_array) if pc.all(na_mask) == pa.scalar(True): - return type(self)(pa_array) + return self._from_pyarrow_array(pa_array) if skipna: if name == "cumsum": pa_array = pc.fill_null(pa_array, "") @@ -1792,7 +1809,7 @@ def _str_accumulate( elif na_mask is not None: pa_result = pc.if_else(na_mask, None, pa_result) - result = type(self)(pa_result) + result = self._from_pyarrow_array(pa_result) return result def _reduce_pyarrow(self, name: str, *, skipna: bool = True, **kwargs) -> pa.Scalar: @@ -1979,7 +1996,7 @@ def _reduce( """ result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs) if isinstance(result, pa.Array): - return type(self)(result) + return self._from_pyarrow_array(result) else: return result @@ -2025,11 +2042,11 @@ def _explode(self): # pc.if_else here is similar to `values[mask] = fill_value` # but this avoids an object-dtype round-trip. pa_values = pc.if_else(~mask, values._pa_array, fill_value) - values = type(self)(pa_values) + values = self._from_pyarrow_array(pa_values) counts = counts.copy() counts[mask] = 1 values = values.fillna(fill_value) - values = type(self)(pa.compute.list_flatten(values._pa_array)) + values = self._from_pyarrow_array(pa.compute.list_flatten(values._pa_array)) return values, counts def __setitem__(self, key, value) -> None: @@ -2235,7 +2252,7 @@ def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str) -> Self: result = result.cast(pa.int64()) result = result.cast(pa_dtype) - return type(self)(result) + return self._from_pyarrow_array(result) def _mode(self, dropna: bool = True) -> Self: """ @@ -2277,7 +2294,7 @@ def _mode(self, dropna: bool = True) -> Self: most_common = most_common.cast(pa_type) most_common = most_common.take(pc.array_sort_indices(most_common)) - return type(self)(most_common) + return self._from_pyarrow_array(most_common) def _maybe_convert_setitem_value(self, value): """Maybe convert value to be pyarrow compatible.""" @@ -2319,7 +2336,7 @@ def interpolate( y_diff_2 = pc.fill_null_backward(pc.pairwise_diff_checked(values, period=2)) prev_values = pa.concat_arrays([na_value, values[:-2], na_value]) interps = pc.add_checked(prev_values, pc.divide_checked(y_diff_2, 2)) - return type(self)(pc.coalesce(self._pa_array, interps)) + return self._from_pyarrow_array(pc.coalesce(self._pa_array, interps)) mask = self.isna() if self.dtype.kind == "f": @@ -2342,7 +2359,7 @@ def interpolate( mask=mask, **kwargs, ) - return type(self)(self._box_pa_array(pa.array(data, mask=mask))) + return self._from_pyarrow_array(self._box_pa_array(pa.array(data, mask=mask))) @classmethod def _if_else( @@ -2505,11 +2522,11 @@ def _groupby_op( return result elif isinstance(result, BaseMaskedArray): pa_result = result.__arrow_array__() - return type(self)(pa_result) + return self._from_pyarrow_array(pa_result) else: # DatetimeArray, TimedeltaArray pa_result = pa.array(result, from_pandas=True) - return type(self)(pa_result) + return self._from_pyarrow_array(pa_result) def _apply_elementwise(self, func: Callable) -> list[list[Any]]: """Apply a callable to each element while maintaining the chunking structure.""" @@ -2524,25 +2541,25 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]: def _convert_bool_result(self, result, na=lib.no_default, method_name=None): if na is not lib.no_default and not isna(na): # pyright: ignore [reportGeneralTypeIssues] result = result.fill_null(na) - return type(self)(result) + return self._from_pyarrow_array(result) def _convert_int_result(self, result): - return type(self)(result) + return self._from_pyarrow_array(result) def _convert_rank_result(self, result): - return type(self)(result) + return self._from_pyarrow_array(result) def _str_count(self, pat: str, flags: int = 0) -> Self: if flags: raise NotImplementedError(f"count not implemented with {flags=}") - return type(self)(pc.count_substring_regex(self._pa_array, pat)) + return self._from_pyarrow_array(pc.count_substring_regex(self._pa_array, pat)) def _str_repeat(self, repeats: int | Sequence[int]) -> Self: if not isinstance(repeats, int): raise NotImplementedError( f"repeat is not implemented when repeats is {type(repeats).__name__}" ) - return type(self)(pc.binary_repeat(self._pa_array, repeats)) + return self._from_pyarrow_array(pc.binary_repeat(self._pa_array, repeats)) def _str_join(self, sep: str) -> Self: if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string( @@ -2552,27 +2569,27 @@ def _str_join(self, sep: str) -> Self: result = pa.chunked_array(result, type=pa.list_(pa.string())) else: result = self._pa_array - return type(self)(pc.binary_join(result, sep)) + return self._from_pyarrow_array(pc.binary_join(result, sep)) def _str_partition(self, sep: str, expand: bool) -> Self: predicate = lambda val: val.partition(sep) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_rpartition(self, sep: str, expand: bool) -> Self: predicate = lambda val: val.rpartition(sep) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_casefold(self) -> Self: predicate = lambda val: val.casefold() result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_encode(self, encoding: str, errors: str = "strict") -> Self: predicate = lambda val: val.encode(encoding, errors) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): if flags: @@ -2583,7 +2600,7 @@ def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): result = pc.extract_regex(self._pa_array, pat) if expand: return { - col: type(self)(pc.struct_field(result, [i])) + col: self._from_pyarrow_array(pc.struct_field(result, [i])) for col, i in zip(groups, range(result.type.num_fields)) } else: @@ -2593,7 +2610,7 @@ def _str_findall(self, pat: str, flags: int = 0) -> Self: regex = re.compile(pat, flags=flags) predicate = lambda val: regex.findall(val) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): if dtype is None: @@ -2616,28 +2633,28 @@ def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): dummies = np.zeros(n_rows * n_cols, dtype=dummies_dtype) dummies[indices] = True dummies = dummies.reshape((n_rows, n_cols)) # type: ignore[assignment] - result = type(self)(pa.array(list(dummies))) + result = self._from_pyarrow_array(pa.array(list(dummies))) return result, uniques_sorted.to_pylist() def _str_index(self, sub: str, start: int = 0, end: int | None = None) -> Self: predicate = lambda val: val.index(sub, start, end) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_rindex(self, sub: str, start: int = 0, end: int | None = None) -> Self: predicate = lambda val: val.rindex(sub, start, end) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_normalize(self, form: Literal["NFC", "NFD", "NFKC", "NFKD"]) -> Self: predicate = lambda val: unicodedata.normalize(form, val) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_rfind(self, sub: str, start: int = 0, end=None) -> Self: predicate = lambda val: val.rfind(sub, start, end) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_split( self, @@ -2654,34 +2671,34 @@ def _str_split( split_func = functools.partial(pc.split_pattern_regex, pattern=pat) else: split_func = functools.partial(pc.split_pattern, pattern=pat) - return type(self)(split_func(self._pa_array, max_splits=n)) + return self._from_pyarrow_array(split_func(self._pa_array, max_splits=n)) def _str_rsplit(self, pat: str | None = None, n: int | None = -1) -> Self: if n in {-1, 0}: n = None if pat is None: - return type(self)( + return self._from_pyarrow_array( pc.utf8_split_whitespace(self._pa_array, max_splits=n, reverse=True) ) - return type(self)( + return self._from_pyarrow_array( pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) ) def _str_translate(self, table: dict[int, str]) -> Self: predicate = lambda val: val.translate(table) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) def _str_wrap(self, width: int, **kwargs) -> Self: kwargs["width"] = width tw = textwrap.TextWrapper(**kwargs) predicate = lambda val: "\n".join(tw.wrap(val)) result = self._apply_elementwise(predicate) - return type(self)(pa.chunked_array(result)) + return self._from_pyarrow_array(pa.chunked_array(result)) @property def _dt_days(self) -> Self: - return type(self)( + return self._from_pyarrow_array( pa.array( self._to_timedeltaarray().components.days, from_pandas=True, @@ -2691,7 +2708,7 @@ def _dt_days(self) -> Self: @property def _dt_hours(self) -> Self: - return type(self)( + return self._from_pyarrow_array( pa.array( self._to_timedeltaarray().components.hours, from_pandas=True, @@ -2701,7 +2718,7 @@ def _dt_hours(self) -> Self: @property def _dt_minutes(self) -> Self: - return type(self)( + return self._from_pyarrow_array( pa.array( self._to_timedeltaarray().components.minutes, from_pandas=True, @@ -2711,7 +2728,7 @@ def _dt_minutes(self) -> Self: @property def _dt_seconds(self) -> Self: - return type(self)( + return self._from_pyarrow_array( pa.array( self._to_timedeltaarray().components.seconds, from_pandas=True, @@ -2721,7 +2738,7 @@ def _dt_seconds(self) -> Self: @property def _dt_milliseconds(self) -> Self: - return type(self)( + return self._from_pyarrow_array( pa.array( self._to_timedeltaarray().components.milliseconds, from_pandas=True, @@ -2731,7 +2748,7 @@ def _dt_milliseconds(self) -> Self: @property def _dt_microseconds(self) -> Self: - return type(self)( + return self._from_pyarrow_array( pa.array( self._to_timedeltaarray().components.microseconds, from_pandas=True, @@ -2741,7 +2758,7 @@ def _dt_microseconds(self) -> Self: @property def _dt_nanoseconds(self) -> Self: - return type(self)( + return self._from_pyarrow_array( pa.array( self._to_timedeltaarray().components.nanoseconds, from_pandas=True, @@ -2756,52 +2773,60 @@ def _dt_to_pytimedelta(self) -> np.ndarray: return np.array(data, dtype=object) def _dt_total_seconds(self) -> Self: - return type(self)( - pa.array(self._to_timedeltaarray().total_seconds(), from_pandas=True) - ) + result = pa.array(self._to_timedeltaarray().total_seconds(), from_pandas=True) + return self._from_pyarrow_array(result) def _dt_as_unit(self, unit: str) -> Self: if pa.types.is_date(self.dtype.pyarrow_dtype): raise NotImplementedError("as_unit not implemented for date types") pd_array = self._maybe_convert_datelike_array() # Don't just cast _pa_array in order to follow pandas unit conversion rules - return type(self)(pa.array(pd_array.as_unit(unit), from_pandas=True)) + result = pa.array(pd_array.as_unit(unit), from_pandas=True) + return self._from_pyarrow_array(result) @property def _dt_year(self) -> Self: - return type(self)(pc.year(self._pa_array)) + result = pc.year(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_day(self) -> Self: - return type(self)(pc.day(self._pa_array)) + result = pc.day(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_day_of_week(self) -> Self: - return type(self)(pc.day_of_week(self._pa_array)) + result = pc.day_of_week(self._pa_array) + return self._from_pyarrow_array(result) _dt_dayofweek = _dt_day_of_week _dt_weekday = _dt_day_of_week @property def _dt_day_of_year(self) -> Self: - return type(self)(pc.day_of_year(self._pa_array)) + result = pc.day_of_year(self._pa_array) + return self._from_pyarrow_array(result) _dt_dayofyear = _dt_day_of_year @property def _dt_hour(self) -> Self: - return type(self)(pc.hour(self._pa_array)) + result = pc.hour(self._pa_array) + return self._from_pyarrow_array(result) def _dt_isocalendar(self) -> Self: - return type(self)(pc.iso_calendar(self._pa_array)) + result = pc.iso_calendar(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_is_leap_year(self) -> Self: - return type(self)(pc.is_leap_year(self._pa_array)) + result = pc.is_leap_year(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_is_month_start(self) -> Self: - return type(self)(pc.equal(pc.day(self._pa_array), 1)) + result = pc.equal(pc.day(self._pa_array), 1) + return self._from_pyarrow_array(result) @property def _dt_is_month_end(self) -> Self: @@ -2812,25 +2837,23 @@ def _dt_is_month_end(self) -> Self: ), 1, ) - return type(self)(result) + return self._from_pyarrow_array(result) @property def _dt_is_year_start(self) -> Self: - return type(self)( - pc.and_( - pc.equal(pc.month(self._pa_array), 1), - pc.equal(pc.day(self._pa_array), 1), - ) + result = pc.and_( + pc.equal(pc.month(self._pa_array), 1), + pc.equal(pc.day(self._pa_array), 1), ) + return self._from_pyarrow_array(result) @property def _dt_is_year_end(self) -> Self: - return type(self)( - pc.and_( - pc.equal(pc.month(self._pa_array), 12), - pc.equal(pc.day(self._pa_array), 31), - ) + result = pc.and_( + pc.equal(pc.month(self._pa_array), 12), + pc.equal(pc.day(self._pa_array), 31), ) + return self._from_pyarrow_array(result) @property def _dt_is_quarter_start(self) -> Self: @@ -2838,7 +2861,7 @@ def _dt_is_quarter_start(self) -> Self: pc.floor_temporal(self._pa_array, unit="quarter"), pc.floor_temporal(self._pa_array, unit="day"), ) - return type(self)(result) + return self._from_pyarrow_array(result) @property def _dt_is_quarter_end(self) -> Self: @@ -2849,7 +2872,7 @@ def _dt_is_quarter_end(self) -> Self: ), 1, ) - return type(self)(result) + return self._from_pyarrow_array(result) @property def _dt_days_in_month(self) -> Self: @@ -2857,7 +2880,7 @@ def _dt_days_in_month(self) -> Self: pc.floor_temporal(self._pa_array, unit="month"), pc.ceil_temporal(self._pa_array, unit="month"), ) - return type(self)(result) + return self._from_pyarrow_array(result) _dt_daysinmonth = _dt_days_in_month @@ -2866,31 +2889,38 @@ def _dt_microsecond(self) -> Self: # GH 59154 us = pc.microsecond(self._pa_array) ms_to_us = pc.multiply(pc.millisecond(self._pa_array), 1000) - return type(self)(pc.add(us, ms_to_us)) + result = pc.add(us, ms_to_us) + return self._from_pyarrow_array(result) @property def _dt_minute(self) -> Self: - return type(self)(pc.minute(self._pa_array)) + result = pc.minute(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_month(self) -> Self: - return type(self)(pc.month(self._pa_array)) + result = pc.month(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_nanosecond(self) -> Self: - return type(self)(pc.nanosecond(self._pa_array)) + result = pc.nanosecond(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_quarter(self) -> Self: - return type(self)(pc.quarter(self._pa_array)) + result = pc.quarter(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_second(self) -> Self: - return type(self)(pc.second(self._pa_array)) + result = pc.second(self._pa_array) + return self._from_pyarrow_array(result) @property def _dt_date(self) -> Self: - return type(self)(self._pa_array.cast(pa.date32())) + result = self._pa_array.cast(pa.date32()) + return self._from_pyarrow_array(result) @property def _dt_time(self) -> Self: @@ -2899,7 +2929,8 @@ def _dt_time(self) -> Self: if self.dtype.pyarrow_dtype.unit in {"us", "ns"} else "ns" ) - return type(self)(self._pa_array.cast(pa.time64(unit))) + result = self._pa_array.cast(pa.time64(unit)) + return self._from_pyarrow_array(result) @property def _dt_tz(self): @@ -2910,10 +2941,12 @@ def _dt_unit(self): return self.dtype.pyarrow_dtype.unit def _dt_normalize(self) -> Self: - return type(self)(pc.floor_temporal(self._pa_array, 1, "day")) + result = pc.floor_temporal(self._pa_array, 1, "day") + return self._from_pyarrow_array(result) def _dt_strftime(self, format: str) -> Self: - return type(self)(pc.strftime(self._pa_array, format=format)) + result = pc.strftime(self._pa_array, format=format) + return self._from_pyarrow_array(result) def _round_temporally( self, @@ -2950,7 +2983,8 @@ def _round_temporally( raise ValueError(f"{freq=} is not supported") multiple = offset.n rounding_method = getattr(pc, f"{method}_temporal") - return type(self)(rounding_method(self._pa_array, multiple=multiple, unit=unit)) + result = rounding_method(self._pa_array, multiple=multiple, unit=unit) + return self._from_pyarrow_array(result) def _dt_ceil( self, @@ -2979,12 +3013,14 @@ def _dt_round( def _dt_day_name(self, locale: str | None = None) -> Self: if locale is None: locale = "C" - return type(self)(pc.strftime(self._pa_array, format="%A", locale=locale)) + result = pc.strftime(self._pa_array, format="%A", locale=locale) + return self._from_pyarrow_array(result) def _dt_month_name(self, locale: str | None = None) -> Self: if locale is None: locale = "C" - return type(self)(pc.strftime(self._pa_array, format="%B", locale=locale)) + result = pc.strftime(self._pa_array, format="%B", locale=locale) + return self._from_pyarrow_array(result) def _dt_to_pydatetime(self) -> Series: from pandas import Series @@ -3023,7 +3059,7 @@ def _dt_tz_localize( result = pc.assume_timezone( self._pa_array, str(tz), ambiguous=ambiguous, nonexistent=nonexistent_pa ) - return type(self)(result) + return self._from_pyarrow_array(result) def _dt_tz_convert(self, tz) -> Self: if self.dtype.pyarrow_dtype.tz is None: @@ -3032,7 +3068,7 @@ def _dt_tz_convert(self, tz) -> Self: ) current_unit = self.dtype.pyarrow_dtype.unit result = self._pa_array.cast(pa.timestamp(current_unit, tz)) - return type(self)(result) + return self._from_pyarrow_array(result) def transpose_homogeneous_pyarrow( diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 4d91f33a8df87..da270da342ee6 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -293,7 +293,6 @@ def construct_array_type(self) -> type_t[BaseStringArray]: """ from pandas.core.arrays.string_arrow import ( ArrowStringArray, - ArrowStringArrayNumpySemantics, ) if self.storage == "python" and self._na_value is libmissing.NA: @@ -303,7 +302,7 @@ def construct_array_type(self) -> type_t[BaseStringArray]: elif self.storage == "python": return StringArrayNumpySemantics else: - return ArrowStringArrayNumpySemantics + return ArrowStringArray def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: storages = set() @@ -340,16 +339,9 @@ def __from_arrow__( Construct StringArray from pyarrow Array/ChunkedArray. """ if self.storage == "pyarrow": - if self._na_value is libmissing.NA: - from pandas.core.arrays.string_arrow import ArrowStringArray - - return ArrowStringArray(array) - else: - from pandas.core.arrays.string_arrow import ( - ArrowStringArrayNumpySemantics, - ) + from pandas.core.arrays.string_arrow import ArrowStringArray - return ArrowStringArrayNumpySemantics(array) + return ArrowStringArray(array, dtype=self) else: import pyarrow @@ -493,6 +485,8 @@ def _str_map_str_or_object( result = pa.array( result, mask=mask, type=pa.large_string(), from_pandas=True ) + # error: "BaseStringArray" has no attribute "_from_pyarrow_array" + return self._from_pyarrow_array(result) # type: ignore[attr-defined] # error: Too many arguments for "BaseStringArray" return type(self)(result) # type: ignore[call-arg] diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 39e84a3abdcf5..5db658075ff82 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -91,6 +91,8 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr ---------- values : pyarrow.Array or pyarrow.ChunkedArray The array of data. + dtype : StringDtype + The dtype for the array. Attributes ---------- @@ -123,10 +125,8 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr # error: Incompatible types in assignment (expression has type "StringDtype", # base class "ArrowExtensionArray" defined the type as "ArrowDtype") _dtype: StringDtype # type: ignore[assignment] - _storage = "pyarrow" - _na_value: libmissing.NAType | float = libmissing.NA - def __init__(self, values) -> None: + def __init__(self, values, *, dtype: StringDtype | None = None) -> None: _chk_pyarrow_available() if isinstance(values, (pa.Array, pa.ChunkedArray)) and ( pa.types.is_string(values.type) @@ -143,7 +143,10 @@ def __init__(self, values) -> None: values = pc.cast(values, pa.large_string()) super().__init__(values) - self._dtype = StringDtype(storage=self._storage, na_value=self._na_value) + + if dtype is None: + dtype = StringDtype(storage="pyarrow", na_value=libmissing.NA) + self._dtype = dtype if not pa.types.is_large_string(self._pa_array.type): raise ValueError( @@ -151,6 +154,13 @@ def __init__(self, values) -> None: "large_string type" ) + def _from_pyarrow_array(self, pa_array): + """ + Construct from the pyarrow array result of an operation, retaining + self.dtype.na_value. + """ + return type(self)(pa_array, dtype=self.dtype) + @classmethod def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar: pa_scalar = super()._box_pa_scalar(value, pa_type) @@ -195,13 +205,15 @@ def _from_sequence( na_values = scalars._mask result = scalars._data result = lib.ensure_string_array(result, copy=copy, convert_na_value=False) - return cls(pa.array(result, mask=na_values, type=pa.large_string())) + pa_arr = pa.array(result, mask=na_values, type=pa.large_string()) elif isinstance(scalars, (pa.Array, pa.ChunkedArray)): - return cls(pc.cast(scalars, pa.large_string())) - - # convert non-na-likes to str - result = lib.ensure_string_array(scalars, copy=copy) - return cls(pa.array(result, type=pa.large_string(), from_pandas=True)) + pa_arr = pc.cast(scalars, pa.large_string()) + else: + # convert non-na-likes to str + result = lib.ensure_string_array(scalars, copy=copy) + pa_arr = pa.array(result, type=pa.large_string(), from_pandas=True) + # error: Argument "dtype" to "ArrowStringArray" has incompatible type + return cls(pa_arr, dtype=dtype) # type: ignore[arg-type] @classmethod def _from_sequence_of_strings( @@ -458,7 +470,7 @@ def _reduce( if name in ("argmin", "argmax") and isinstance(result, pa.Array): return self._convert_int_result(result) elif isinstance(result, pa.Array): - return type(self)(result) + return type(self)(result, dtype=self.dtype) else: return result @@ -490,7 +502,3 @@ def _cmp_method(self, other, op): def __pos__(self) -> Self: raise TypeError(f"bad operand type for unary +: '{self.dtype}'") - - -class ArrowStringArrayNumpySemantics(ArrowStringArray): - _na_value = np.nan diff --git a/pandas/core/construction.py b/pandas/core/construction.py index 46e3e47afb2ac..c909c4f3828c3 100644 --- a/pandas/core/construction.py +++ b/pandas/core/construction.py @@ -176,7 +176,7 @@ def array( NumPy array. >>> pd.array(["a", "b"], dtype=str) - + ['a', 'b'] Length: 2, dtype: str diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index b67397b99dc20..b0ae89b1b954d 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -110,7 +110,6 @@ class providing the base-class of operations. from pandas.core.arrays.string_ import StringDtype from pandas.core.arrays.string_arrow import ( ArrowStringArray, - ArrowStringArrayNumpySemantics, ) from pandas.core.base import ( PandasObject, @@ -2896,10 +2895,11 @@ def size(self) -> DataFrame | Series: dtype_backend: None | Literal["pyarrow", "numpy_nullable"] = None if isinstance(self.obj, Series): if isinstance(self.obj.array, ArrowExtensionArray): - if isinstance(self.obj.array, ArrowStringArrayNumpySemantics): - dtype_backend = None - elif isinstance(self.obj.array, ArrowStringArray): - dtype_backend = "numpy_nullable" + if isinstance(self.obj.array, ArrowStringArray): + if self.obj.array.dtype.na_value is np.nan: + dtype_backend = None + else: + dtype_backend = "numpy_nullable" else: dtype_backend = "pyarrow" elif isinstance(self.obj.array, BaseMaskedArray): diff --git a/pandas/core/series.py b/pandas/core/series.py index 6055e65c2786b..85dd1da833615 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -762,7 +762,7 @@ def values(self): array([1, 2, 3]) >>> pd.Series(list("aabc")).values - + ['a', 'a', 'b', 'c'] Length: 4, dtype: str diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 96e1cc05e284c..9dae3ae384255 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -24,7 +24,6 @@ from pandas.core.arrays.string_ import StringArrayNumpySemantics from pandas.core.arrays.string_arrow import ( ArrowStringArray, - ArrowStringArrayNumpySemantics, ) @@ -113,7 +112,7 @@ def test_repr(dtype): arr_name = "ArrowStringArray" expected = f"<{arr_name}>\n['a', , 'b']\nLength: 3, dtype: string" elif dtype.storage == "pyarrow" and dtype.na_value is np.nan: - arr_name = "ArrowStringArrayNumpySemantics" + arr_name = "ArrowStringArray" expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str" elif dtype.storage == "python" and dtype.na_value is np.nan: arr_name = "StringArrayNumpySemantics" @@ -482,10 +481,12 @@ def test_from_sequence_no_mutate(copy, cls, dtype): result = cls._from_sequence(nan_arr, dtype=dtype, copy=copy) - if cls in (ArrowStringArray, ArrowStringArrayNumpySemantics): + if cls is ArrowStringArray: import pyarrow as pa - expected = cls(pa.array(na_arr, type=pa.string(), from_pandas=True)) + expected = cls( + pa.array(na_arr, type=pa.string(), from_pandas=True), dtype=dtype + ) elif cls is StringArrayNumpySemantics: expected = cls(nan_arr) else: diff --git a/pandas/tests/arrays/string_/test_string_arrow.py b/pandas/tests/arrays/string_/test_string_arrow.py index 2b5f60ce70b4c..626e03a900316 100644 --- a/pandas/tests/arrays/string_/test_string_arrow.py +++ b/pandas/tests/arrays/string_/test_string_arrow.py @@ -14,7 +14,6 @@ ) from pandas.core.arrays.string_arrow import ( ArrowStringArray, - ArrowStringArrayNumpySemantics, ) @@ -186,9 +185,6 @@ def test_pyarrow_not_installed_raises(): with pytest.raises(ImportError, match=msg): ArrowStringArray([]) - with pytest.raises(ImportError, match=msg): - ArrowStringArrayNumpySemantics([]) - with pytest.raises(ImportError, match=msg): ArrowStringArray._from_sequence(["a", None, "b"]) diff --git a/pandas/tests/base/test_conversion.py b/pandas/tests/base/test_conversion.py index 821f51ee95ad3..2ef0e49399e21 100644 --- a/pandas/tests/base/test_conversion.py +++ b/pandas/tests/base/test_conversion.py @@ -24,7 +24,7 @@ TimedeltaArray, ) from pandas.core.arrays.string_ import StringArrayNumpySemantics -from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics +from pandas.core.arrays.string_arrow import ArrowStringArray class TestToIterable: @@ -222,9 +222,7 @@ def test_iter_box_period(self): ) def test_values_consistent(arr, expected_type, dtype, using_infer_string): if using_infer_string and dtype == "object": - expected_type = ( - ArrowStringArrayNumpySemantics if HAS_PYARROW else StringArrayNumpySemantics - ) + expected_type = ArrowStringArray if HAS_PYARROW else StringArrayNumpySemantics l_values = Series(arr)._values r_values = pd.Index(arr)._values assert type(l_values) is expected_type diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index d8203c2e2e350..34d0ee9f819a0 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -49,7 +49,7 @@ def maybe_split_array(arr, chunked): [*arrow_array[:split].chunks, *arrow_array[split:].chunks] ) assert arrow_array.num_chunks == 2 - return type(arr)(arrow_array) + return arr._from_pyarrow_array(arrow_array) @pytest.fixture(params=[True, False]) diff --git a/pandas/tests/reductions/test_reductions.py b/pandas/tests/reductions/test_reductions.py index 807cf19269c85..db27572b9da26 100644 --- a/pandas/tests/reductions/test_reductions.py +++ b/pandas/tests/reductions/test_reductions.py @@ -29,7 +29,7 @@ ) import pandas._testing as tm from pandas.core import nanops -from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics +from pandas.core.arrays.string_arrow import ArrowStringArray def get_objs(): @@ -61,7 +61,7 @@ class TestReductions: def test_ops(self, opname, obj): result = getattr(obj, opname)() if not isinstance(obj, PeriodIndex): - if isinstance(obj.values, ArrowStringArrayNumpySemantics): + if isinstance(obj.values, ArrowStringArray): # max not on the interface expected = getattr(np.array(obj.values), opname)() else: