Skip to content

Commit bfdfb5a

Browse files
committed
Closes #5228: remove type: ignore from factorize in extension module
1 parent d780663 commit bfdfb5a

File tree

3 files changed

+128
-156
lines changed

3 files changed

+128
-156
lines changed

arkouda/pandas/extension/_arkouda_extension_array.py

Lines changed: 76 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
import numpy as np
5050

51+
from numpy.typing import NDArray
5152
from pandas.api.extensions import ExtensionArray
5253

5354
from arkouda.numpy.dtypes import all_scalars
@@ -349,45 +350,47 @@ def take(self, indexer, fill_value=None, allow_fill=False):
349350
gathered = ak.where(mask, fv, self._data[idx_fix])
350351
return type(self)(gathered)
351352

352-
def factorize( # type: ignore[override]
353-
self, use_na_sentinel=True, sort=False, **kwargs
354-
) -> Tuple["ArkoudaExtensionArray", "ArkoudaExtensionArray"]:
353+
def factorize(self, use_na_sentinel=True) -> Tuple[NDArray[np.intp], "ArkoudaExtensionArray"]:
355354
"""
356-
Encode the values of this array as integer codes and uniques,
357-
similar to :func:`pandas.factorize`, but implemented with Arkouda.
355+
Encode the values of this array as integer codes and unique values.
356+
357+
This is similar to :func:`pandas.factorize`, but the grouping/factorization
358+
work is performed in Arkouda. The returned ``codes`` are a NumPy array for
359+
pandas compatibility, while ``uniques`` are returned as an ExtensionArray
360+
of the same type as ``self``.
358361
359362
Each distinct non-missing value is assigned a unique integer code.
360-
Missing values (NaN in floating dtypes) are encoded as -1 by default.
363+
For floating dtypes, ``NaN`` is treated as missing; for all other dtypes,
364+
no values are considered missing.
361365
362366
Parameters
363367
----------
364368
use_na_sentinel : bool, default True
365-
If True, missing values are encoded as -1 in the codes array.
366-
If False, missing values are assigned a valid code equal to
367-
``len(uniques)``.
368-
sort : bool, default False
369-
Whether to sort the unique values. If False, the unique values
370-
appear in the order of first appearance in the array. If True,
371-
the unique values are sorted, and codes are assigned accordingly.
372-
**kwargs
373-
Ignored for compatibility.
369+
If True, missing values are encoded as ``-1`` in the returned codes.
370+
If False, missing values are assigned the code ``len(uniques)``.
371+
(Missingness is only detected for floating dtypes via ``NaN``.)
374372
375373
Returns
376374
-------
377-
Tuple[pdarray, ArkoudaExtensionArray]
375+
(numpy.ndarray, ExtensionArray)
378376
A pair ``(codes, uniques)`` where:
379-
- ``codes`` is a NumPy ``int64`` array of factor labels, one per element.
380-
Missing values are ``-1`` if ``use_na_sentinel=True``; otherwise they
381-
receive the code ``len(uniques)``.
382-
- ``uniques`` is a NumPy array of the unique values.
377+
378+
* ``codes`` is a 1D NumPy array of dtype ``np.intp`` with the same length
379+
as this array, containing the factor codes for each element.
380+
* ``uniques`` is an ExtensionArray containing the unique (non-missing)
381+
values, with the same extension type as ``self``.
382+
383+
If ``use_na_sentinel=True``, missing values in ``codes`` are ``-1``.
384+
Otherwise they receive the code ``len(uniques)``.
383385
384386
Notes
385387
-----
386388
* Only floating-point dtypes treat ``NaN`` as missing; for other dtypes,
387-
no values are considered missing.
388-
* This method executes all grouping and factorization in Arkouda,
389-
returning results as NumPy arrays for compatibility with pandas.
390-
* Unlike pandas, string/None/null handling is not yet unified.
389+
all values are treated as non-missing.
390+
* ``uniques`` are constructed from Arkouda's unique keys and returned as
391+
``type(self)(uniques_ak)`` so that pandas internals (e.g. ``groupby``)
392+
can treat them as an ExtensionArray.
393+
* String/None/null missing-value behavior is not yet unified with pandas.
391394
392395
Examples
393396
--------
@@ -396,7 +399,7 @@ def factorize( # type: ignore[override]
396399
>>> arr = ArkoudaArray(ak.array([1, 2, 1, 3]))
397400
>>> codes, uniques = arr.factorize()
398401
>>> codes
399-
ArkoudaArray([0 1 0 2])
402+
array([0, 1, 0, 2])
400403
>>> uniques
401404
ArkoudaArray([1 2 3])
402405
"""
@@ -407,7 +410,6 @@ def factorize( # type: ignore[override]
407410
from arkouda.numpy.pdarraycreation import array as ak_array
408411
from arkouda.numpy.sorting import argsort
409412
from arkouda.numpy.strings import Strings
410-
from arkouda.pandas.extension import ArkoudaArray
411413
from arkouda.pandas.groupbyclass import GroupBy
412414

413415
# Arkouda array backing
@@ -425,7 +427,7 @@ def factorize( # type: ignore[override]
425427
sent = -1 if use_na_sentinel else 0
426428
from arkouda.numpy.pdarraycreation import full as ak_full
427429

428-
return ArkoudaArray(ak_full(n, sent, dtype=int64)), type(self)(
430+
return ak_full(n, sent, dtype=int64).to_ndarray(), type(self)(
429431
ak_array([], dtype=self.to_numpy().dtype)
430432
)
431433

@@ -437,28 +439,16 @@ def factorize( # type: ignore[override]
437439

438440
uniques_ak = concatenate(uniques_ak)
439441

440-
if sort:
441-
# Keys already sorted; group id -> 0..k-1
442-
groupid_to_code = arange(uniques_ak.size, dtype=int64)
443-
444-
# Work around to account GroupBy not sorting Categorical properly
445-
if isinstance(arr, Categorical):
446-
perm = uniques_ak.argsort()
447-
# Inverse argsort:
448-
groupid_to_code[perm] = arange(uniques_ak.size, dtype=int64)
449-
uniques_ak = uniques_ak[perm]
450-
451-
else:
452-
# First-appearance order
453-
_keys, first_idx_per_group = g.min(arange(arr_nn.size, dtype=int64))
454-
order = argsort(first_idx_per_group)
442+
# First-appearance order
443+
_keys, first_idx_per_group = g.min(arange(arr_nn.size, dtype=int64))
444+
order = argsort(first_idx_per_group)
455445

456-
# Reorder uniques by first appearance
457-
uniques_ak = uniques_ak[order]
446+
# Reorder uniques by first appearance
447+
uniques_ak = uniques_ak[order]
458448

459-
# Map group_id -> code in first-appearance order
460-
groupid_to_code = zeros(order.size, dtype=int64)
461-
groupid_to_code[order] = arange(order.size, dtype=int64)
449+
# Map group_id -> code in first-appearance order
450+
groupid_to_code = zeros(order.size, dtype=int64)
451+
groupid_to_code[order] = arange(order.size, dtype=int64)
462452

463453
# Per-element codes on the non-NA slice
464454
codes_nn = g.broadcast(groupid_to_code)
@@ -468,7 +458,9 @@ def factorize( # type: ignore[override]
468458
codes_ak = full(n, sentinel, dtype=int64)
469459
codes_ak[non_na] = codes_nn
470460

471-
return ArkoudaArray(codes_ak), type(self)(uniques_ak)
461+
codes_np = codes_ak.to_ndarray().astype(np.intp, copy=False)
462+
463+
return codes_np, type(self)(uniques_ak)
472464

473465
# In each EA
474466
def _values_for_factorize(self):
@@ -527,42 +519,45 @@ def to_ndarray(self) -> np.ndarray:
527519
"""
528520
return self._data.to_ndarray()
529521

530-
def argsort( # type: ignore[override]
522+
def argsort(
531523
self,
532524
*,
533525
ascending: bool = True,
534-
kind="quicksort",
535-
na_position: str = "last",
536-
**kwargs,
537-
) -> pdarray:
526+
kind: str = "quicksort",
527+
**kwargs: object,
528+
) -> NDArray[np.intp]:
538529
"""
539530
Return the indices that would sort the array.
540531
541-
This method computes the permutation indices that would sort the
542-
underlying Arkouda data. It aligns with the pandas ``ExtensionArray``
543-
contract, returning a 1-D ``pdarray`` of integer indices suitable for
544-
reordering the array via ``take`` or ``iloc``. NaN values are placed
545-
either at the beginning or end of the result depending on
546-
``na_position``.
532+
This method computes the permutation indices that would sort the underlying
533+
Arkouda data and returns them as a NumPy array, in accordance with the
534+
pandas ``ExtensionArray`` contract. The indices can be used to reorder the
535+
array via ``take`` or ``iloc``.
536+
537+
For floating-point data, ``NaN`` values are handled according to the
538+
``na_position`` keyword argument.
547539
548540
Parameters
549541
----------
550542
ascending : bool, default True
551-
If True, sort values in ascending order. If False, sort in
552-
descending order.
543+
If True, sort values in ascending order. If False, sort in descending
544+
order.
553545
kind : str, default "quicksort"
554-
Sorting algorithm. Present for API compatibility with NumPy and
555-
pandas but currently ignored.
556-
na_position : {"first", "last"}, default "last"
557-
Where to place NaN values in the sorted result. Currently only implemented for pdarray.
558-
For Strings and Categorical will have no effect.
559-
**kwargs : Any
560-
Additional keyword arguments for compatibility; ignored.
546+
Sorting algorithm. Present for API compatibility with NumPy and pandas
547+
but currently ignored.
548+
**kwargs
549+
Additional keyword arguments for compatibility. Supported keyword:
550+
551+
* ``na_position`` : {"first", "last"}, default "last"
552+
Where to place ``NaN`` values in the sorted result. This option is
553+
currently only applied for floating-point ``pdarray`` data; for
554+
``Strings`` and ``Categorical`` data it has no effect.
561555
562556
Returns
563557
-------
564-
pdarray
565-
Integer indices (``int64``) that would sort the array.
558+
numpy.ndarray
559+
A 1D NumPy array of dtype ``np.intp`` containing the indices that would
560+
sort the array.
566561
567562
Raises
568563
------
@@ -573,21 +568,22 @@ def argsort( # type: ignore[override]
573568
574569
Notes
575570
-----
576-
- Supports Arkouda ``pdarray``, ``Strings``, and ``Categorical`` data.
577-
- Floating-point arrays have NaNs repositioned according to
571+
* Supports Arkouda ``pdarray``, ``Strings``, and ``Categorical`` data.
572+
* For floating-point arrays, ``NaN`` values are repositioned according to
578573
``na_position``.
579-
- This method does not move data to the client; the computation
580-
occurs on the Arkouda server.
574+
* The sorting computation occurs on the Arkouda server, but the resulting
575+
permutation indices are materialized on the client as a NumPy array, as
576+
required by pandas internals.
581577
582578
Examples
583579
--------
584580
>>> import arkouda as ak
585581
>>> from arkouda.pandas.extension import ArkoudaArray
586582
>>> a = ArkoudaArray(ak.array([3.0, float("nan"), 1.0]))
587583
>>> a.argsort() # NA last by default
588-
array([2 0 1])
584+
array([2, 0, 1])
589585
>>> a.argsort(na_position="first")
590-
array([1 2 0])
586+
array([1, 2, 0])
591587
"""
592588
from arkouda.numpy import argsort
593589
from arkouda.numpy.numeric import isnan as ak_isnan
@@ -596,6 +592,9 @@ def argsort( # type: ignore[override]
596592
from arkouda.numpy.util import is_float
597593
from arkouda.pandas.categorical import Categorical
598594

595+
# Extract na_position from kwargs
596+
na_position = kwargs.pop("na_position", "last")
597+
599598
if na_position not in {"first", "last"}:
600599
raise ValueError("na_position must be 'first' or 'last'.")
601600

@@ -613,7 +612,7 @@ def argsort( # type: ignore[override]
613612
else:
614613
raise TypeError(f"Unsupported argsort dtype: {type(self._data)}")
615614

616-
return perm
615+
return perm.to_ndarray()
617616

618617
def broadcast_arrays(self, *arrays):
619618
raise NotImplementedError(

0 commit comments

Comments
 (0)