|
1 | | -from typing import Tuple |
| 1 | +from typing import Dict, Tuple |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import torch |
|
10 | 10 | "extract_patches_numpy", |
11 | 11 | "stitch_patches_numpy", |
12 | 12 | "TilerStitcher", |
| 13 | + "get_patches", |
13 | 14 | ] |
14 | 15 |
|
15 | 16 |
|
@@ -246,6 +247,143 @@ def stitch_patches_torch( |
246 | 247 | return output |
247 | 248 |
|
248 | 249 |
|
| 250 | +def _get_margins_and_pad( |
| 251 | + first_endpoint: int, img_size: int, stride: int, pad: int = None |
| 252 | +) -> Tuple[int, int]: |
| 253 | + """Get the number of slices needed for one direction and the overlap.""" |
| 254 | + pad = int(pad) if pad is not None else 20 # at least some padding needed |
| 255 | + img_size += pad |
| 256 | + |
| 257 | + n = 1 |
| 258 | + mod = 0 |
| 259 | + end = first_endpoint |
| 260 | + while True: |
| 261 | + n += 1 |
| 262 | + end += stride |
| 263 | + |
| 264 | + if end > img_size: |
| 265 | + mod = end - img_size |
| 266 | + break |
| 267 | + elif end == img_size: |
| 268 | + break |
| 269 | + |
| 270 | + return n, mod + pad |
| 271 | + |
| 272 | + |
| 273 | +def _get_slices( |
| 274 | + stride: int, |
| 275 | + patch_size: Tuple[int, int], |
| 276 | + img_size: Tuple[int, int], |
| 277 | + pad: int = None, |
| 278 | +) -> Tuple[Dict[str, slice], int, int]: |
| 279 | + """Get all the overlapping slices in a dictionary and the needed paddings.""" |
| 280 | + y_end, x_end = patch_size |
| 281 | + nrows, pady = _get_margins_and_pad(y_end, img_size[0], stride, pad=pad) |
| 282 | + ncols, padx = _get_margins_and_pad(x_end, img_size[1], stride, pad=pad) |
| 283 | + |
| 284 | + xyslices = {} |
| 285 | + for row in range(nrows): |
| 286 | + for col in range(ncols): |
| 287 | + y_start = row * stride |
| 288 | + y_end = y_start + patch_size[0] |
| 289 | + x_start = col * stride |
| 290 | + x_end = x_start + patch_size[1] |
| 291 | + xyslices[f"y-{y_start}_x-{x_start}"] = ( |
| 292 | + slice(y_start, y_end), |
| 293 | + slice(x_start, x_end), |
| 294 | + ) |
| 295 | + |
| 296 | + return xyslices, pady, padx, nrows, ncols |
| 297 | + |
| 298 | + |
| 299 | +def get_patches( |
| 300 | + arr: np.ndarray, stride: int, patch_size: Tuple[int, int], padding: int = None |
| 301 | +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple[int, ...], int, int]: |
| 302 | + """Patch an input array to overlapping or non-overlapping patches. |
| 303 | +
|
| 304 | + NOTE: some padding is applied by default to make the arr divisible by patch_size. |
| 305 | +
|
| 306 | + Parameters |
| 307 | + ---------- |
| 308 | + arr : np.ndarray |
| 309 | + An array of shape: (H, W, C) or (H, W). |
| 310 | + stride : int |
| 311 | + Stride of the sliding window. |
| 312 | + patch_size : Tuple[int, int] |
| 313 | + Height and width of the patch |
| 314 | + padding : int, optional |
| 315 | + Size of reflection padding. |
| 316 | +
|
| 317 | + Returns |
| 318 | + -------- |
| 319 | + Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple[int, ...], int, int]: |
| 320 | + - Batched input array of shape: (n_patches, ph, pw)|(n_patches, ph, pw, C) |
| 321 | + - Batched repeats array of shape: (n_patches, ph, pw) int32 |
| 322 | + - Repeat matrix of shape (H + pady, W + padx). Dtype: int32 |
| 323 | + - The shape of the padded input array. |
| 324 | + - nrows |
| 325 | + - ncols |
| 326 | + """ |
| 327 | + shape = arr.shape |
| 328 | + if len(shape) == 2: |
| 329 | + arr_type = "HW" |
| 330 | + elif len(shape) == 3: |
| 331 | + arr_type = "HWC" |
| 332 | + else: |
| 333 | + raise ValueError("`arr` needs to be either 'HW' or 'HWC' shape.") |
| 334 | + |
| 335 | + slices, pady, padx, nrows, ncols = _get_slices( |
| 336 | + stride, patch_size, (shape[0], shape[1]), padding |
| 337 | + ) |
| 338 | + |
| 339 | + padx, modx = divmod(padx, 2) |
| 340 | + pady, mody = divmod(pady, 2) |
| 341 | + padx += modx |
| 342 | + pady += mody |
| 343 | + |
| 344 | + pad = [(pady, pady), (padx, padx)] |
| 345 | + if arr_type == "HWC": |
| 346 | + pad.append((0, 0)) |
| 347 | + |
| 348 | + arr = np.pad(arr, pad, mode="reflect") |
| 349 | + |
| 350 | + # init repeats matrix + add padding repeats |
| 351 | + if padding != 0 or padding is None: |
| 352 | + repeats = np.ones(arr.shape[:2]) |
| 353 | + repeats[pady:-pady, padx:-padx] = 0 |
| 354 | + |
| 355 | + # corner pads |
| 356 | + repeats[:pady, :padx] += 1 |
| 357 | + repeats[-pady:, -padx:] += 1 |
| 358 | + repeats[-pady:, :padx] += 1 |
| 359 | + repeats[:pady, -padx:] += 1 |
| 360 | + else: |
| 361 | + repeats = np.zeros(arr.shape[:2]) |
| 362 | + |
| 363 | + patches = [] |
| 364 | + rep_patches = [] |
| 365 | + for yslice, xslice in slices.values(): |
| 366 | + if arr_type == "HW": |
| 367 | + patch = arr[yslice, xslice] |
| 368 | + elif arr_type == "HWC": |
| 369 | + patch = arr[yslice, xslice, ...] |
| 370 | + |
| 371 | + rep_patch = repeats[yslice, xslice] |
| 372 | + repeats[yslice, xslice] += 1 |
| 373 | + |
| 374 | + patches.append(patch) |
| 375 | + rep_patches.append(rep_patch) |
| 376 | + |
| 377 | + return ( |
| 378 | + np.array(patches, dtype="uint8"), |
| 379 | + np.array(rep_patches, dtype="int32"), |
| 380 | + repeats.astype("int32"), |
| 381 | + arr.shape, |
| 382 | + nrows, |
| 383 | + ncols, |
| 384 | + ) |
| 385 | + |
| 386 | + |
249 | 387 | class TilerStitcher: |
250 | 388 | def __init__( |
251 | 389 | self, |
|
0 commit comments