Skip to content

Commit 9e3b005

Browse files
authored
Merge pull request #114 from chrishavlin/fix_ndarray_typing
improving type hints related to numpy
2 parents deb9f93 + cef0118 commit 9e3b005

File tree

5 files changed

+20
-18
lines changed

5 files changed

+20
-18
lines changed

.github/workflows/type-check.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: type checking
33
on:
44
pull_request:
55
paths:
6-
- yt_experiments/**/*.py
6+
- yt_xarray/**/*.py
77
- pyproject.toml
88
- requirements/typecheck.txt
99
- .github/workflows/type-checking.yaml

yt_xarray/accessor/_xr_to_yt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def _check_grid_stretchiness(x: npt.NDArray) -> _GridType:
494494
return _GridType.STRETCHED
495495

496496

497-
def _check_for_time(dim_name, dim_vals: np.ndarray):
497+
def _check_for_time(dim_name, dim_vals: npt.NDArray) -> bool:
498498
return "time" in dim_name.lower() or type(dim_vals) is np.datetime64
499499

500500

yt_xarray/accessor/accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import xarray as xr
55
import yt
6-
from numpy.typing import ArrayLike
6+
from numpy.typing import ArrayLike, NDArray
77
from unyt import unyt_quantity
88
from yt.data_objects.static_output import Dataset as ytDataset
99

@@ -210,7 +210,7 @@ def get_bbox(
210210
field: str,
211211
sel_dict: dict[str, Any] | None = None,
212212
sel_dict_type: str = "isel",
213-
) -> np.ndarray:
213+
) -> NDArray:
214214
"""
215215
return the bounding box array for a field, with possible selections
216216

yt_xarray/utilities/_grid_decomposition.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -456,28 +456,30 @@ class ChunkInfo:
456456
def __init__(
457457
self,
458458
data_shp: Tuple[int, ...],
459-
chunksizes: npt.NDArray,
460-
starting_index_offset: npt.NDArray | None = None,
459+
chunksizes: npt.NDArray[np.int64],
460+
starting_index_offset: npt.NDArray[np.int64] | None = None,
461461
):
462462

463463
self.chunksizes = chunksizes
464-
self.data_shape = np.asarray(data_shp)
465-
self.n_chnk = self.data_shape / chunksizes # may not be int
464+
self.data_shape = np.asarray(data_shp).astype(np.int64)
465+
self.n_chnk: npt.NDArray[np.int64] = (
466+
self.data_shape / chunksizes
467+
) # may not be int
466468
self.n_whl_chnk = np.floor(self.n_chnk).astype(int) # whole chunks in each dim
467469
self.n_part_chnk = np.ceil(self.n_chnk - self.n_whl_chnk).astype(int)
468470
self.n_tots = np.prod(self.n_part_chnk + self.n_whl_chnk)
469471

470472
self.ndim = len(data_shp)
471473
if starting_index_offset is None:
472-
starting_index_offset = np.zeros(self.data_shape.shape, dtype=int)
473-
self.starting_index_offset = starting_index_offset
474+
starting_index_offset = np.zeros(self.data_shape.shape, dtype=np.int64)
475+
self.starting_index_offset: npt.NDArray[np.int64] = starting_index_offset
474476

475-
_si: List[npt.NDArray] | None = None
476-
_ei: List[npt.NDArray] | None = None
477-
_sizes: List[npt.NDArray] | None = None
477+
_si: List[npt.NDArray[np.int64]] | None = None
478+
_ei: List[npt.NDArray[np.int64]] | None = None
479+
_sizes: List[npt.NDArray[np.int64]] | None = None
478480

479481
@property
480-
def si(self) -> List[npt.NDArray]:
482+
def si(self) -> List[npt.NDArray[np.int64]]:
481483
"""
482484
The starting indices of individual chunks by dimension.
483485
Includes any global offset.
@@ -523,7 +525,7 @@ def si(self) -> List[npt.NDArray]:
523525
return self._si
524526

525527
@property
526-
def ei(self) -> List[npt.NDArray]:
528+
def ei(self) -> List[npt.NDArray[np.int64]]:
527529
"""
528530
The ending indices of individual chunks by dimension.
529531
Includes any global offset.
@@ -534,7 +536,7 @@ def ei(self) -> List[npt.NDArray]:
534536
return self._ei
535537

536538
@property
537-
def sizes(self) -> List[npt.NDArray]:
539+
def sizes(self) -> List[npt.NDArray[np.int64]]:
538540
if self._sizes is None:
539541
_ = self.si
540542
assert self._sizes is not None

yt_xarray/utilities/_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _test_time_coord(nt: int = 5) -> npt.NDArray:
176176

177177
def _get_test_coord(
178178
cname, n, minv: float | None = None, maxv: float | None = None
179-
) -> np.ndarray:
179+
) -> npt.NDArray[np.floating]:
180180
if cname in known_coord_aliases:
181181
cname = known_coord_aliases[cname]
182182

@@ -211,7 +211,7 @@ def construct_ds_with_extra_dim(
211211
ncoords: int | None = None,
212212
nd_space: int = 3,
213213
reverse_indices: list[int] | None = None,
214-
):
214+
) -> xr.Dataset:
215215
coord_configs = {
216216
0: (dim_name, "x", "y", "z"),
217217
1: (dim_name, "z", "y", "x"),

0 commit comments

Comments
 (0)