diff --git a/arkouda/numpy/pdarrayclass.py b/arkouda/numpy/pdarrayclass.py index c4eb8a2c972..9b224a838da 100644 --- a/arkouda/numpy/pdarrayclass.py +++ b/arkouda/numpy/pdarrayclass.py @@ -6,7 +6,7 @@ from functools import reduce from math import ceil from sys import modules -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast, overload +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union, cast, overload import numpy as np @@ -2644,7 +2644,17 @@ def bigint_to_uint_arrays(self) -> List[pdarray]: ret_list = json.loads(generic_msg(cmd=cmd, args={"array": self})) return list(reversed([create_pdarray(a) for a in ret_list])) - def reshape(self, *shape): + @overload + def reshape(self, shape: Sequence[int_scalars]) -> pdarray: ... + + @overload + def reshape(self, *shape: int_scalars) -> pdarray: ... + + @overload + def reshape(self, shape: pdarray) -> pdarray: ... + + @typechecked + def reshape(self, *shape: Union[int_scalars, Sequence[int_scalars], pdarray]) -> pdarray: """ Gives a new shape to an array without changing its data. @@ -2678,21 +2688,40 @@ def reshape(self, *shape): # For example, a.reshape(10, 11) is equivalent to a.reshape((10, 11)) # the lenshape variable addresses an error that occurred when a single integer was # passed + from typing import get_args + from arkouda.client import generic_msg + shape_seq: Sequence[int_scalars] + if len(shape) == 1: - shape = shape[0] - lenshape = 1 - if (not isinstance(shape, int)) and (not isinstance(shape, pdarray)): - shape = [i for i in shape] - lenshape = len(shape) + arg = shape[0] + + if isinstance(arg, get_args(int_scalars)): + shape_seq = (arg,) + + elif isinstance(arg, Sequence): + shape_seq = tuple(arg) + + elif isinstance(arg, pdarray): + shape_seq = cast(Sequence[int_scalars], arg.tolist()) + + else: + raise TypeError(f"Invalid shape argument {shape}") + + else: + shape_seq = cast(Sequence[int_scalars], shape) + + shape_arg: list[int_scalars] = list(shape_seq) + + lenshape = len(shape_arg) return create_pdarray( generic_msg( cmd=f"reshape<{self.dtype},{self.ndim},{lenshape}>", args={ "name": self.name, - "shape": shape, + "shape": shape_arg, }, ), max_bits=self.max_bits, diff --git a/tests/numpy/pdarrayclass_test.py b/tests/numpy/pdarrayclass_test.py index 830bf641128..316764c327f 100644 --- a/tests/numpy/pdarrayclass_test.py +++ b/tests/numpy/pdarrayclass_test.py @@ -143,7 +143,7 @@ def test_flatten(self, size, dtype): def test_flatten_multidim(self, size, dtype): size = size - (size % 4) a = ak.arange(size, dtype=dtype) - b = a.reshape((2, 2, size / 4)) + b = a.reshape((2, 2, size // 4)) ak_assert_equal(b.flatten(), a) @pytest.mark.parametrize("size", pytest.prob_size)