Skip to content

Commit 5ade8cc

Browse files
authored
Merge pull request #35 from HiDiHlabs/shape-typing
Shape typing for numpy
2 parents 818558e + 8037157 commit 5ade8cc

File tree

9 files changed: +171 additions, -81 deletions

sainsc/_typealias.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
from typing import TypeAlias
33

44
import numpy as np
5-
from numpy.typing import NDArray
65
from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix
76

87
_PathLike: TypeAlias = os.PathLike[str] | str
@@ -15,7 +14,20 @@
1514
_RangeTuple: TypeAlias = tuple[int, int]
1615
_RangeTuple2D: TypeAlias = tuple[_RangeTuple, _RangeTuple]
1716

18-
_Local_Max: TypeAlias = tuple[NDArray[np.int_], NDArray[np.int_]]
17+
18+
_Local_Max: TypeAlias = tuple[
19+
np.ndarray[tuple[int], np.dtype[np.int_]], np.ndarray[tuple[int], np.dtype[np.int_]]
20+
]
1921

2022
_Color: TypeAlias = str | tuple[float, float, float]
2123
_Cmap: TypeAlias = str | list[_Color] | dict[str, _Color]
24+
25+
26+
_Coord: TypeAlias = np.ndarray[tuple[int], np.dtype[np.int32]]
27+
_CosineMap: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.float32]]
28+
_AssignmentScoreMap: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.float32]]
29+
_Kernel: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.float32]]
30+
_SignatureArray: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.float32]]
31+
_Background: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.bool_]]
32+
_CountMap: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.unsignedinteger]]
33+
_KDE: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.single]]

sainsc/_utils.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import pandas as pd
7-
from numpy.typing import NDArray
87

98
from ._utils_rust import coordinate_as_string
109

@@ -38,15 +37,18 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
3837
return wrapper
3938

4039

40+
N = TypeVar("N", bound=int)
41+
42+
4143
def _get_coordinate_index(
42-
x: NDArray[np.integer],
43-
y: NDArray[np.integer],
44+
x: np.ndarray[tuple[N], np.dtype[np.integer]],
45+
y: np.ndarray[tuple[N], np.dtype[np.integer]],
4446
*,
4547
name: str | None = None,
4648
n_threads: int | None = None,
4749
) -> pd.Index:
48-
x_i32: NDArray[np.int32] = x.astype(np.int32, copy=False)
49-
y_i32: NDArray[np.int32] = y.astype(np.int32, copy=False)
50+
x_i32: np.ndarray[tuple[N], np.dtype[np.int32]] = x.astype(np.int32, copy=False)
51+
y_i32: np.ndarray[tuple[N], np.dtype[np.int32]] = y.astype(np.int32, copy=False)
5052

5153
return pd.Index(
5254
coordinate_as_string(x_i32, y_i32, n_threads=n_threads), dtype=str, name=name

sainsc/_utils_rust.pyi

Lines changed: 34 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,20 @@
11
from typing import Self
22

33
import numpy as np
4-
from numpy.typing import NDArray
54
from polars import DataFrame
65

7-
from ._typealias import _Csx, _CsxArray
6+
from ._typealias import (
7+
_AssignmentScoreMap,
8+
_Coord,
9+
_CosineMap,
10+
_Csx,
11+
_CsxArray,
12+
_Kernel,
13+
_SignatureArray,
14+
)
815

916
def sparse_kde_csx_py(
10-
counts: _Csx, kernel: NDArray[np.float32], *, threshold: float = 0
17+
counts: _Csx, kernel: _Kernel, *, threshold: float = 0
1118
) -> _CsxArray:
1219
"""
1320
Calculate the KDE for each spot with counts as uint16.
@@ -16,8 +23,11 @@ def sparse_kde_csx_py(
1623
def kde_at_coord(
1724
counts: GridCounts,
1825
genes: list[str],
19-
kernel: NDArray[np.float32],
20-
coordinates: tuple[NDArray[np.int_], NDArray[np.int_]],
26+
kernel: _Kernel,
27+
coordinates: tuple[
28+
np.ndarray[tuple[int], np.dtype[np.int_]],
29+
np.ndarray[tuple[int], np.dtype[np.int_]],
30+
],
2131
*,
2232
n_threads: int | None = None,
2333
) -> _CsxArray:
@@ -26,29 +36,34 @@ def kde_at_coord(
2636
"""
2737

2838
def categorical_coordinate(
29-
x: NDArray[np.int32], y: NDArray[np.int32], *, n_threads: int | None = None
30-
) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
39+
x: _Coord, y: _Coord, *, n_threads: int | None = None
40+
) -> tuple[
41+
np.ndarray[tuple[int], np.dtype[np.int32]],
42+
np.ndarray[tuple[int, int], np.dtype[np.int32]],
43+
]:
3144
"""
3245
Get the codes and the coordinates (comparable to a pandas.Categorical)
3346
"""
3447

3548
def coordinate_as_string(
36-
x: NDArray[np.int32], y: NDArray[np.int32], *, n_threads: int | None = None
37-
) -> NDArray[np.str_]:
49+
x: _Coord, y: _Coord, *, n_threads: int | None = None
50+
) -> np.ndarray[tuple[int], np.dtype[np.str_]]:
3851
"""
3952
Concatenate two int arrays elementwise into a string representation (i.e. 'x_y').
4053
"""
4154

4255
def cosinef32_and_celltypei8(
4356
counts: GridCounts,
4457
genes: list[str],
45-
signatures: NDArray[np.float32],
46-
kernel: NDArray[np.float32],
58+
signatures: _SignatureArray,
59+
kernel: _Kernel,
4760
*,
4861
log: bool = False,
4962
chunk_size: tuple[int, int] = (500, 500),
5063
n_threads: int | None = None,
51-
) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.int8]]:
64+
) -> tuple[
65+
_CosineMap, _AssignmentScoreMap, np.ndarray[tuple[int, int], np.dtype[np.int8]]
66+
]:
5267
"""
5368
Calculate the cosine similarity given counts and signatures and assign the most
5469
similar celltype.
@@ -57,13 +72,15 @@ def cosinef32_and_celltypei8(
5772
def cosinef32_and_celltypei16(
5873
counts: GridCounts,
5974
genes: list[str],
60-
signatures: NDArray[np.float32],
61-
kernel: NDArray[np.float32],
75+
signatures: _SignatureArray,
76+
kernel: _Kernel,
6277
*,
6378
log: bool = False,
6479
chunk_size: tuple[int, int] = (500, 500),
6580
n_threads: int | None = None,
66-
) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.int16]]:
81+
) -> tuple[
82+
_CosineMap, _AssignmentScoreMap, np.ndarray[tuple[int, int], np.dtype[np.int16]]
83+
]:
6784
"""
6885
Calculate the cosine similarity given counts and signatures and assign the most
6986
similar celltype.
@@ -186,7 +203,7 @@ class GridCounts:
186203
Mapping from gene to number of counts.
187204
"""
188205

189-
def grid_counts(self) -> NDArray[np.uintc]:
206+
def grid_counts(self) -> np.ndarray[tuple[int, int], np.dtype[np.uintc]]:
190207
"""
191208
Counts per pixel.
192209
@@ -231,7 +248,7 @@ class GridCounts:
231248
Range to crop as `(ymin, ymax)`
232249
"""
233250

234-
def filter_mask(self, mask: NDArray[np.bool_]):
251+
def filter_mask(self, mask: np.ndarray[tuple[int, int], np.dtype[np.bool_]]):
235252
"""
236253
Filter all genes with a binary mask.
237254

sainsc/datasets.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import TYPE_CHECKING
24

35
import pandas as pd
@@ -13,7 +15,7 @@
1315
version = "v" + __version__
1416

1517

16-
def _get_signature_pooch() -> "Pooch":
18+
def _get_signature_pooch() -> Pooch:
1719
# use indirection to enable pooch as optional dependency w/o lazy loading
1820
try:
1921
import pooch

sainsc/io/_io.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from collections.abc import Collection
24
from pathlib import Path
35
from typing import TYPE_CHECKING, Literal, get_args
@@ -457,7 +459,7 @@ def read_StereoSeq_bins(
457459
sep: str = "\t",
458460
n_threads: int | None = None,
459461
**kwargs,
460-
) -> "AnnData | SpatialData":
462+
) -> AnnData | SpatialData:
461463
"""
462464
Read a Stereo-seq GEM file into bins.
463465
@@ -521,7 +523,10 @@ def read_StereoSeq_bins(
521523

522524
obs = pd.DataFrame(
523525
index=_get_coordinate_index(
524-
coordinates[:, 0], coordinates[:, 1], name="bin", n_threads=n_threads
526+
coordinates[:, 0], # type: ignore
527+
coordinates[:, 1], # type: ignore
528+
name="bin",
529+
n_threads=n_threads,
525530
)
526531
)
527532

sainsc/io/_io_utils.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,8 @@
55

66
import numpy as np
77
import polars as pl
8-
from numpy.typing import NDArray
98

10-
from .._typealias import _PathLike
9+
from .._typealias import _Coord, _PathLike
1110
from .._utils_rust import categorical_coordinate
1211

1312

@@ -20,8 +19,11 @@ def _bin_coordinates(df: pl.DataFrame, bin_size: float) -> pl.DataFrame:
2019

2120

2221
def _categorical_coordinate(
23-
x: NDArray[np.int32], y: NDArray[np.int32], *, n_threads: int | None = None
24-
) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
22+
x: _Coord, y: _Coord, *, n_threads: int | None = None
23+
) -> tuple[
24+
np.ndarray[tuple[int], np.dtype[np.int32]],
25+
np.ndarray[tuple[int, int], np.dtype[np.int32]],
26+
]:
2527
assert len(x) == len(y)
2628

2729
return categorical_coordinate(x, y, n_threads=n_threads)

0 commit comments

Comments
 (0)