Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions arkouda/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,18 @@ def apply(
if result_type != arr.dtype:
raise TypeError("result_dtype must match the dtype of the input")

repMsg = generic_msg(
rep_msg = generic_msg(
cmd=f"applyStr<{arr.dtype},{arr.ndim}>",
args={"x": arr, "funcStr": func},
)
return create_pdarray(repMsg)
return create_pdarray(rep_msg)
elif callable(func):
pickleData = cloudpickle.dumps(func)
pickleDataStr = base64.b64encode(pickleData).decode("utf-8")
repMsg = generic_msg(
pickle_data = cloudpickle.dumps(func)
pickle_data_str = base64.b64encode(pickle_data).decode("utf-8")
rep_msg = generic_msg(
cmd=f"applyPickle<{arr.dtype},{arr.ndim},{result_type}>",
args={"x": arr, "pickleData": pickleDataStr},
args={"x": arr, "pickleData": pickle_data_str},
)
return create_pdarray(repMsg)
return create_pdarray(rep_msg)
else:
raise TypeError("func must be a string or a callable function")
34 changes: 20 additions & 14 deletions arkouda/array_api/_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np

import arkouda as ak

from arkouda import dtype as akdtype

from ._typing import Dtype


__all__: list[str] = []

Expand All @@ -22,21 +26,23 @@
complex128 = ak.complex128
bool_ = ak.bool_

_all_dtypes = (
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
complex64,
complex128,
bool,

_all_dtypes: tuple[Dtype, ...] = (
np.dtype(np.int8),
np.dtype(np.int16),
np.dtype(np.int32),
np.dtype(np.int64),
np.dtype(np.uint8),
np.dtype(np.uint16),
np.dtype(np.uint32),
np.dtype(np.uint64),
np.dtype(np.float32),
np.dtype(np.float64),
np.dtype(np.complex64),
np.dtype(np.complex128),
np.dtype(np.bool_),
)

_boolean_dtypes = (bool,)
_real_floating_dtypes = (float32, float64)
_floating_dtypes = (float32, float64, complex64, complex128)
Expand Down
35 changes: 19 additions & 16 deletions arkouda/array_api/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

from __future__ import annotations

from typing import Any, Literal, Protocol, TypeVar, Union
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, TypeVar, Union

from numpy import dtype, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64
import numpy as np

from .array_object import Array

if TYPE_CHECKING:
from .array_object import Array


__all__ = [
Expand All @@ -37,19 +39,20 @@ def __len__(self, /) -> int: ...
Device = Literal["cpu"]


Dtype = dtype[
Union[
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
]
Dtype: TypeAlias = Union[
np.dtype[np.int8],
np.dtype[np.int16],
np.dtype[np.int32],
np.dtype[np.int64],
np.dtype[np.uint8],
np.dtype[np.uint16],
np.dtype[np.uint32],
np.dtype[np.uint64],
np.dtype[np.float32],
np.dtype[np.float64],
np.dtype[np.complex64],
np.dtype[np.complex128],
np.dtype[np.bool_],
]


Expand Down
2 changes: 1 addition & 1 deletion arkouda/array_api/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def eye(
raise ValueError(f"Unsupported device {device!r}")

if M is None:
M = N
M = N # noqa: N806

from arkouda import dtype as akdtype

Expand Down
42 changes: 24 additions & 18 deletions arkouda/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,22 +133,28 @@ def isdtype(dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]
@implements_numpy(np.result_type)
def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
"""Compute the result dtype for a group of arrays and/or dtypes."""
A = []
for a in arrays_and_dtypes:
if isinstance(a, Array):
a = a.dtype
elif isinstance(a, np.ndarray):
a = a.dtype
elif a not in _all_dtypes:
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
A.append(a)

if len(A) == 0:
dtypes: list[Dtype] = []

for obj in arrays_and_dtypes:
if isinstance(obj, Array):
dt: Dtype = obj.dtype
elif isinstance(obj, np.ndarray):
# If you truly allow numpy arrays here, you may need a mapping step.
# If Array.dtype already returns a numpy dtype, this may be fine.
dt = obj.dtype # type: ignore[assignment]
else:
dt = obj
if dt not in _all_dtypes:
raise TypeError("result_type() inputs must be array_api arrays or dtypes")

dtypes.append(dt)

if len(dtypes) == 0:
raise ValueError("at least one array or dtype is required")
elif len(A) == 1:
return A[0]
else:
t = A[0]
for t2 in A[1:]:
t = _result_type(t, t2)
return t
if len(dtypes) == 1:
return dtypes[0]

t = dtypes[0]
for t2 in dtypes[1:]:
t = _result_type(t, t2)
return t
20 changes: 10 additions & 10 deletions arkouda/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def broadcast_arrays(*arrays: Array) -> List[Array]:
from arkouda.numpy.util import broadcast_shapes

shapes = [a.shape for a in arrays]
bcShape = broadcast_shapes(*shapes)
bc_shape = broadcast_shapes(*shapes)

return [broadcast_to(a, shape=bcShape) for a in arrays]
return [broadcast_to(a, shape=bc_shape) for a in arrays]


@implements_numpy(np.broadcast_to)
Expand Down Expand Up @@ -293,9 +293,9 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
from arkouda.client import generic_msg
from arkouda.numpy.util import _axis_validation

axisList = []
axis_list = []
if axis is not None:
valid, axisList = _axis_validation(axis, x.ndim)
valid, axis_list = _axis_validation(axis, x.ndim)
if not valid:
raise IndexError(f"{axis} is not a valid axis/axes for array of rank {x.ndim}")

Expand All @@ -312,8 +312,8 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
),
args={
"name": x._array,
"nAxes": len(axisList),
"axis": axisList,
"nAxes": len(axis_list),
"axis": axis_list,
},
),
)
Expand Down Expand Up @@ -610,9 +610,9 @@ def roll(
if isinstance(shift, tuple) and isinstance(axis, tuple) and (len(axis) != len(shift)):
raise IndexError("When shift and axis are both tuples, they must have the same length.")

axisList = []
axis_list = []
if axis is not None:
valid, axisList = _axis_validation(axis, x.ndim)
valid, axis_list = _axis_validation(axis, x.ndim)
if not valid:
raise IndexError(f"{axis} is not a valid axis/axes for array of rank {x.ndim}")

Expand All @@ -631,8 +631,8 @@ def roll(
"name": x._array,
"nShifts": len(shift) if isinstance(shift, tuple) else 1,
"shift": (list(shift) if isinstance(shift, tuple) else [shift]),
"nAxes": len(axisList),
"axis": axisList,
"nAxes": len(axis_list),
"axis": axis_list,
},
),
)
Expand Down
8 changes: 4 additions & 4 deletions arkouda/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def _mem_get_factor(unit: str) -> int:
)


logger = getArkoudaLogger(name="Arkouda Client", logLevel=LogLevel.INFO)
clientLogger = getArkoudaLogger(name="Arkouda User Logger", logFormat="%(message)s")
logger = getArkoudaLogger(name="Arkouda Client", log_level=LogLevel.INFO)
clientLogger = getArkoudaLogger(name="Arkouda User Logger", log_format="%(message)s")


class ClientMode(Enum):
Expand Down Expand Up @@ -1452,8 +1452,8 @@ def get_server_commands() -> Mapping[str, str]:

def print_server_commands():
"""Print the list of available server commands."""
cmdMap = get_server_commands()
cmds = [k for k in sorted(cmdMap.keys())]
cmd_map = get_server_commands()
cmds = [k for k in sorted(cmd_map.keys())]
sys.stdout.write(f"Total available server commands: {len(cmds)}")
for cmd in cmds:
sys.stdout.write(f"\t{cmd}")
Expand Down
Loading