Skip to content

Commit 7cc1349

Browse files
committed
feat: add simple patching func
1 parent 630a866 commit 7cc1349

File tree

3 files changed

+164
-16
lines changed

3 files changed

+164
-16
lines changed

cellseg_models_pytorch/datasets/dataset_writers/_base_writer.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathos.multiprocessing import ThreadPool as Pool
77

88
from ...transforms.albu_transforms import IMG_TRANSFORMS, compose
9-
from ...utils import FileHandler, TilerStitcher, fix_duplicates, remap_label
9+
from ...utils import FileHandler, fix_duplicates, get_patches, remap_label
1010

1111
__all__ = ["BaseWriter"]
1212

@@ -184,11 +184,7 @@ def _get_tiles(
184184
masks: Union[Dict[str, np.ndarray], None] = None,
185185
) -> Tuple[Dict[str, np.ndarray], Union[Dict[str, np.ndarray], None]]:
186186
"""Do tiling on an image and corresponding masks if there are such."""
187-
# Init Tilers
188-
im_tiler = TilerStitcher(
189-
im_shape=im.shape, patch_shape=self.patch_size + (3,), stride=self.stride
190-
)
191-
im_tiles = im_tiler.patch(im)
187+
im_tiles = get_patches(im, self.stride, self.patch_size)[0]
192188

193189
# Tile masks if there are such.
194190
mask_tiles = None
@@ -204,18 +200,18 @@ def _get_tiles(
204200
if "sem_map" in masks.keys():
205201
sem = masks["sem_map"]
206202

207-
mask_tiler = TilerStitcher(
208-
im_shape=inst.shape,
209-
patch_shape=self.patch_size + (1,),
210-
stride=self.stride,
211-
)
212-
213203
if inst is not None:
214-
mask_tiles["inst_map"] = mask_tiler.patch(inst).squeeze()
204+
mask_tiles["inst_map"] = get_patches(
205+
inst, self.stride, self.patch_size
206+
)[0]
215207
if types is not None:
216-
mask_tiles["type_map"] = mask_tiler.patch(types).squeeze()
208+
mask_tiles["type_map"] = get_patches(
209+
types, self.stride, self.patch_size
210+
)[0]
217211
if sem is not None:
218-
mask_tiles["sem_map"] = mask_tiler.patch(sem).squeeze()
212+
mask_tiles["sem_map"] = get_patches(sem, self.stride, self.patch_size)[
213+
0
214+
]
219215

220216
return im_tiles, mask_tiles
221217

cellseg_models_pytorch/utils/patching.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple
1+
from typing import Dict, Tuple
22

33
import numpy as np
44
import torch
@@ -10,6 +10,7 @@
1010
"extract_patches_numpy",
1111
"stitch_patches_numpy",
1212
"TilerStitcher",
13+
"get_patches",
1314
]
1415

1516

@@ -246,6 +247,143 @@ def stitch_patches_torch(
246247
return output
247248

248249

250+
def _get_margins_and_pad(
251+
first_endpoint: int, img_size: int, stride: int, pad: int = None
252+
) -> Tuple[int, int]:
253+
"""Get the number of slices needed for one direction and the overlap."""
254+
pad = int(pad) if pad is not None else 20 # at least some padding needed
255+
img_size += pad
256+
257+
n = 1
258+
mod = 0
259+
end = first_endpoint
260+
while True:
261+
n += 1
262+
end += stride
263+
264+
if end > img_size:
265+
mod = end - img_size
266+
break
267+
elif end == img_size:
268+
break
269+
270+
return n, mod + pad
271+
272+
273+
def _get_slices(
274+
stride: int,
275+
patch_size: Tuple[int, int],
276+
img_size: Tuple[int, int],
277+
pad: int = None,
278+
) -> Tuple[Dict[str, slice], int, int]:
279+
"""Get all the overlapping slices in a dictionary and the needed paddings."""
280+
y_end, x_end = patch_size
281+
nrows, pady = _get_margins_and_pad(y_end, img_size[0], stride, pad=pad)
282+
ncols, padx = _get_margins_and_pad(x_end, img_size[1], stride, pad=pad)
283+
284+
xyslices = {}
285+
for row in range(nrows):
286+
for col in range(ncols):
287+
y_start = row * stride
288+
y_end = y_start + patch_size[0]
289+
x_start = col * stride
290+
x_end = x_start + patch_size[1]
291+
xyslices[f"y-{y_start}_x-{x_start}"] = (
292+
slice(y_start, y_end),
293+
slice(x_start, x_end),
294+
)
295+
296+
return xyslices, pady, padx, nrows, ncols
297+
298+
299+
def get_patches(
300+
arr: np.ndarray, stride: int, patch_size: Tuple[int, int], padding: int = None
301+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple[int, ...], int, int]:
302+
"""Patch an input array to overlapping or non-overlapping patches.
303+
304+
NOTE: some padding is applied by default to make the arr divisible by patch_size.
305+
306+
Parameters
307+
----------
308+
arr : np.ndarray
309+
An array of shape: (H, W, C) or (H, W).
310+
stride : int
311+
Stride of the sliding window.
312+
patch_size : Tuple[int, int]
313+
Height and width of the patch
314+
padding : int, optional
315+
Size of reflection padding.
316+
317+
Returns
318+
--------
319+
Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple[int, ...], int, int]:
320+
- Batched input array of shape: (n_patches, ph, pw)|(n_patches, ph, pw, C)
321+
- Batched repeats array of shape: (n_patches, ph, pw) int32
322+
- Repeat matrix of shape (H + pady, W + padx). Dtype: int32
323+
- The shape of the padded input array.
324+
- nrows
325+
- ncols
326+
"""
327+
shape = arr.shape
328+
if len(shape) == 2:
329+
arr_type = "HW"
330+
elif len(shape) == 3:
331+
arr_type = "HWC"
332+
else:
333+
raise ValueError("`arr` needs to be either 'HW' or 'HWC' shape.")
334+
335+
slices, pady, padx, nrows, ncols = _get_slices(
336+
stride, patch_size, (shape[0], shape[1]), padding
337+
)
338+
339+
padx, modx = divmod(padx, 2)
340+
pady, mody = divmod(pady, 2)
341+
padx += modx
342+
pady += mody
343+
344+
pad = [(pady, pady), (padx, padx)]
345+
if arr_type == "HWC":
346+
pad.append((0, 0))
347+
348+
arr = np.pad(arr, pad, mode="reflect")
349+
350+
# init repeats matrix + add padding repeats
351+
if padding != 0 or padding is None:
352+
repeats = np.ones(arr.shape[:2])
353+
repeats[pady:-pady, padx:-padx] = 0
354+
355+
# corner pads
356+
repeats[:pady, :padx] += 1
357+
repeats[-pady:, -padx:] += 1
358+
repeats[-pady:, :padx] += 1
359+
repeats[:pady, -padx:] += 1
360+
else:
361+
repeats = np.zeros(arr.shape[:2])
362+
363+
patches = []
364+
rep_patches = []
365+
for yslice, xslice in slices.values():
366+
if arr_type == "HW":
367+
patch = arr[yslice, xslice]
368+
elif arr_type == "HWC":
369+
patch = arr[yslice, xslice, ...]
370+
371+
rep_patch = repeats[yslice, xslice]
372+
repeats[yslice, xslice] += 1
373+
374+
patches.append(patch)
375+
rep_patches.append(rep_patch)
376+
377+
return (
378+
np.array(patches, dtype="uint8"),
379+
np.array(rep_patches, dtype="int32"),
380+
repeats.astype("int32"),
381+
arr.shape,
382+
nrows,
383+
ncols,
384+
)
385+
386+
249387
class TilerStitcher:
250388
def __init__(
251389
self,

cellseg_models_pytorch/utils/tests/test_patching.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
TilerStitcherTorch,
88
extract_patches_numpy,
99
extract_patches_torch,
10+
get_patches,
1011
stitch_patches_numpy,
1112
stitch_patches_torch,
1213
)
@@ -102,3 +103,16 @@ def test_tilerstitchertorch(rand_tensor, padding):
102103
stitched = ts.backstitch(patches)
103104

104105
assert stitched.shape == rand_tensor.shape
106+
107+
108+
@pytest.mark.parametrize("padding", [10, None])
109+
def test_dict_patching(img_sample, padding):
110+
stride = 32
111+
patch_size = (64, 64)
112+
padding = None
113+
patches, _, repeats, padded_shape, _, _ = get_patches(
114+
img_sample, stride, patch_size, padding
115+
)
116+
117+
assert patches.shape[1:-1] == patch_size
118+
assert repeats.shape == padded_shape[:-1]

0 commit comments

Comments
 (0)