Skip to content

Commit b7907e5

Browse files
authored
Register gpu (#20)
1 parent fe78a80 commit b7907e5

File tree

4 files changed

+223
-86
lines changed

4 files changed

+223
-86
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* @maltekuehl

pyproject.toml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ packages = ["spatiomic"]
77

88
[project]
99
name = "spatiomic"
10-
version = "0.8.0"
10+
version = "0.9.0"
1111
description = "A python toolbox for spatial omics analysis."
1212
requires-python = ">=3.11"
1313
license = { file = "LICENSE" }
@@ -206,18 +206,18 @@ dev = [
206206

207207
[project.optional-dependencies]
208208
cellpose = ["cellpose>=4.0.1,<5"]
209-
spatialdata = ["spatialdata==0.5.0"]
209+
spatialdata = ["spatialdata==0.6.1"]
210210
cuda-12 = [
211-
"cuml-cu12>=24.6.0",
212-
"cugraph-cu12>=24.6.0",
213-
"nx-cugraph-cu12>=24.6.0",
214-
"cucim-cu12>=24.6.0",
211+
"cuml-cu12>=24.10.0",
212+
"cugraph-cu12>=24.10.0",
213+
"nx-cugraph-cu12>=24.10.0",
214+
"cucim-cu12>=24.10.0",
215215
]
216216
cuda-11 = [
217-
"cuml-cu11>=24.6.0",
218-
"cugraph-cu11>=24.6.0",
219-
"nx-cugraph-cu11>=24.6.0",
220-
"cucim-cu11>=24.6.0",
217+
"cuml-cu11>=24.10.0",
218+
"cugraph-cu11>=24.10.0",
219+
"nx-cugraph-cu11>=24.10.0",
220+
"cucim-cu11>=24.10.0",
221221
]
222222

223223
[tool.uv]

spatiomic/process/_register.py

Lines changed: 121 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,122 @@
1111
class Register:
1212
"""Expose registration methods."""
1313

14+
@staticmethod
15+
def _preprocess_images(
16+
pixels: NDArray,
17+
reference_pixels: NDArray,
18+
blur: bool = False,
19+
match_histogram: bool = False,
20+
threshold: bool = False,
21+
threshold_percentile: Union[int, float] = 70,
22+
use_gpu: bool = True,
23+
) -> Tuple[NDArray, NDArray]:
24+
"""Preprocess images with optional blur, histogram matching, and thresholding.
25+
26+
Args:
27+
pixels (NDArray): The pixels to preprocess.
28+
reference_pixels (NDArray): The reference pixels to preprocess.
29+
blur (bool, optional): Whether to apply Gaussian blur. Defaults to False.
30+
match_histogram (bool, optional): Whether to match histograms. Defaults to False.
31+
threshold (bool, optional): Whether to apply thresholding. Defaults to False.
32+
threshold_percentile (Union[int, float], optional): Percentile for thresholding. Defaults to 70.
33+
use_gpu (bool, optional): Whether to use GPU acceleration. Defaults to True.
34+
35+
Returns:
36+
Tuple[NDArray, NDArray]: The preprocessed pixels and reference pixels.
37+
"""
38+
if use_gpu:
39+
try:
40+
import cupy as cp # type: ignore
41+
42+
pixels_gpu = cp.array(pixels)
43+
reference_pixels_gpu = cp.array(reference_pixels)
44+
45+
if blur:
46+
from cucim.skimage.filters import gaussian # type: ignore
47+
48+
pixels_gpu = gaussian(pixels_gpu)
49+
reference_pixels_gpu = gaussian(reference_pixels_gpu)
50+
51+
if match_histogram:
52+
from cucim.skimage.exposure import match_histograms # type: ignore
53+
54+
pixels_gpu = match_histograms(pixels_gpu, reference_pixels_gpu)
55+
56+
if threshold:
57+
threshold_limit = cp.percentile(reference_pixels_gpu, threshold_percentile)
58+
reference_pixels_gpu = cp.where(reference_pixels_gpu < threshold_limit, 0, reference_pixels_gpu)
59+
pixels_gpu = cp.where(pixels_gpu < threshold_limit, 0, pixels_gpu)
60+
61+
return pixels_gpu.get(), reference_pixels_gpu.get() # type: ignore
62+
except Exception:
63+
use_gpu = False
64+
65+
if blur:
66+
from skimage.filters import gaussian
67+
68+
pixels = gaussian(pixels)
69+
reference_pixels = gaussian(reference_pixels)
70+
71+
if match_histogram:
72+
from skimage.exposure import match_histograms
73+
74+
pixels = match_histograms(pixels, reference_pixels)
75+
76+
if threshold:
77+
threshold_limit = np.percentile(reference_pixels, threshold_percentile)
78+
reference_pixels = np.where(reference_pixels < threshold_limit, 0, reference_pixels)
79+
pixels = np.where(pixels < threshold_limit, 0, pixels)
80+
81+
return pixels, reference_pixels
82+
1483
@staticmethod
1584
def get_ssim(
1685
pixels: NDArray,
1786
reference_pixels: NDArray,
87+
use_gpu: bool = True,
1888
) -> float:
1989
"""Calculate the structural similarity index measure.
2090
2191
Args:
2292
pixels (NDArray): A 2D array of pixels.
2393
reference_pixels (NDArray): The 2D reference array for calculation of the structural similarity.
94+
use_gpu (bool, optional): Whether to use the cucim GPU implementation. Defaults to True.
2495
2596
Returns:
2697
float: The structural similarity index measure.
2798
"""
99+
if use_gpu:
100+
try:
101+
import cupy as cp # type: ignore
102+
from cucim.skimage.metrics import ( # type: ignore
103+
structural_similarity,
104+
)
105+
106+
pixels = cp.array(pixels)
107+
reference_pixels = cp.array(reference_pixels)
108+
data_range = float(cp.max(pixels) - cp.min(pixels))
109+
110+
ssim = structural_similarity(
111+
pixels,
112+
reference_pixels,
113+
full=False,
114+
data_range=data_range,
115+
)
116+
117+
return float(ssim.get()) # type: ignore
118+
except Exception:
119+
use_gpu = False
120+
28121
from skimage.metrics import structural_similarity
29122

123+
data_range = float(np.max(pixels) - np.min(pixels))
124+
30125
ssim = structural_similarity(
31126
pixels,
32127
reference_pixels,
33128
full=False,
34-
data_range=np.max(pixels) - np.min(pixels),
129+
data_range=data_range,
35130
)
36131

37132
return float(ssim)
@@ -65,30 +160,22 @@ def get_shift(
65160
Defaults to "phase_correlation".
66161
upsample_factor (int, optional): The upsample factor to use for the phase correlation method.
67162
Defaults to 1.
68-
use_gpu (bool, optional): Whether to use the cucim phase_correlation gpu implementation.
69-
Defaults to True.
163+
use_gpu (bool, optional): Whether to use cucim GPU implementations. Defaults to True.
70164
71165
Returns:
72166
Tuple[float, float]: The offset on the y- and the x-axis.
73167
"""
74-
if blur:
75-
from skimage.filters import gaussian
76-
77-
pixels = gaussian(pixels)
78-
reference_pixels = gaussian(reference_pixels)
79-
80-
if match_histogram:
81-
from skimage.exposure import match_histograms
82-
83-
pixels = match_histograms(pixels, reference_pixels)
84-
85-
if threshold:
86-
threshold_limit = np.percentile(reference_pixels, threshold_percentile)
87-
reference_pixels[reference_pixels < threshold_limit] = 0
88-
pixels[pixels < threshold_limit] = 0
168+
pixels, reference_pixels = cls._preprocess_images(
169+
pixels,
170+
reference_pixels,
171+
blur=blur,
172+
match_histogram=match_histogram,
173+
threshold=threshold,
174+
threshold_percentile=threshold_percentile,
175+
use_gpu=use_gpu,
176+
)
89177

90178
if method == "chi2_shift":
91-
# requires image_registration and typing_extensions>=3.10.0.1 and fftw for best performance
92179
from image_registration.chi2_shifts import chi2_shift # type: ignore
93180

94181
shift = chi2_shift(
@@ -105,17 +192,18 @@ def get_shift(
105192
(offset_y, offset_x) = cls.get_phase_shift(
106193
pixels=pixels,
107194
reference_pixels=reference_pixels,
108-
blur=blur,
109-
match_histogram=match_histogram,
110-
threshold=threshold,
195+
blur=False,
196+
match_histogram=False,
197+
threshold=False,
111198
use_gpu=use_gpu,
112199
upsample_factor=upsample_factor,
113200
)
114201

115202
return (offset_y, offset_x)
116203

117-
@staticmethod
204+
@classmethod
118205
def get_phase_shift(
206+
cls,
119207
pixels: NDArray,
120208
reference_pixels: NDArray,
121209
blur: bool = False,
@@ -137,27 +225,20 @@ def get_phase_shift(
137225
percentile of the reference. Defaults to False.
138226
upsample_factor (int, optional): The upsample factor to use for the phase correlation method.
139227
Defaults to 1.
140-
use_gpu (bool, optional): Whether to use the cucim phase_correlation gpu implementation instead of chi2
141-
shift. Defaults to True.
228+
use_gpu (bool, optional): Whether to use cucim GPU implementations. Defaults to True.
142229
143230
Returns:
144231
Tuple[float, float]: The offset on the y- and the x-axis.
145232
"""
146-
if blur:
147-
from skimage.filters import gaussian
148-
149-
pixels = gaussian(pixels)
150-
reference_pixels = gaussian(reference_pixels)
151-
152-
if match_histogram:
153-
from skimage.exposure import match_histograms
154-
155-
pixels = match_histograms(pixels, reference_pixels)
156-
157-
if threshold:
158-
threshold_limit = np.percentile(reference_pixels, 70) # type: ignore
159-
reference_pixels[reference_pixels < threshold_limit] = 0
160-
pixels[pixels < threshold_limit] = 0
233+
pixels, reference_pixels = cls._preprocess_images(
234+
pixels,
235+
reference_pixels,
236+
blur=blur,
237+
match_histogram=match_histogram,
238+
threshold=threshold,
239+
threshold_percentile=70,
240+
use_gpu=use_gpu,
241+
)
161242

162243
if use_gpu:
163244
try:

0 commit comments

Comments
 (0)