Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions src/snaphu/_check.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import math

import numpy as np

from .io import InputDataset, OutputDataset
Expand All @@ -15,6 +17,9 @@
]


LARGESHORT = 32000 # needs to match LARGESHORT in snaphu.h


def check_2d_shapes(**shapes: tuple[int, ...]) -> None:
"""
Ensure that the input tuples are valid 2-D shapes.
Expand Down Expand Up @@ -70,6 +75,79 @@ def check_dataset_shapes(
raise ValueError(errmsg)


def check_dataset_sizes(
    ntiles: tuple[int, int],
    tile_overlap: int | tuple[int, int],
    *,
    regrow_conncomps: bool = True,
    single_tile_reoptimize: bool = False,
    **datasets: InputDataset | OutputDataset,
) -> None:
    """
    Ensure that one or more datasets have shape that SNAPHU can handle.

    SNAPHU cannot process an array (or tile) whose dimensions exceed
    ``LARGESHORT`` pixels. When connected components are regrown or single-tile
    reoptimization is used (or no tiling is requested), the full array is fed
    to SNAPHU as a single tile and must satisfy the limit itself; otherwise
    each tile's dimensions are checked.

    Parameters
    ----------
    ntiles : (int, int)
        Number of tiles used in each dimension.
    tile_overlap : int or (int, int)
        Overlap between tiles, in pixels. A single integer applies to both
        dimensions.
    regrow_conncomps : bool, optional
        Whether to regrow connected components after tiled unwrapping.
        Defaults to True.
    single_tile_reoptimize : bool, optional
        Whether to use single tile reoptimization after tiled unwrapping.
        Defaults to False.
    **datasets : dict, optional
        Datasets to be processed with SNAPHU. The name of each keyword argument
        is used to format the error message in case of a size exception.

    Raises
    ------
    ValueError
        If any dataset (or any tile of a dataset) has a too-large shape.
    TypeError
        If the tile overlap's type is unknown.
    """
    if isinstance(tile_overlap, int):
        y_overlap = x_overlap = tile_overlap
    elif isinstance(tile_overlap, tuple):
        # Cannot unpack directly -- the tuple's length is validated later.
        y_overlap = tile_overlap[0]
        x_overlap = tile_overlap[1]
    else:
        msg = f"got {type(tile_overlap)=}, expected int or tuple"
        raise TypeError(msg)

    # These don't depend on the individual dataset, so compute them once
    # outside the loop.
    skip_tiling = ntiles == (1, 1)
    single_tile_input = regrow_conncomps or single_tile_reoptimize or skip_tiling

    def _calc_tile_shape(array_len: int, num_tiles: int, overlap_len: int) -> int:
        # Tile shape calc similar to SetupTile in snaphu.c.
        return math.ceil((array_len + (num_tiles - 1) * overlap_len) / num_tiles)

    for name, arr in datasets.items():
        if single_tile_input:
            # The full array is input to SNAPHU as a single tile.
            if any(n > LARGESHORT for n in arr.shape):
                msg = (
                    f"{name} dataset with shape {arr.shape} exceeds max dimensions"
                    " supported by SNAPHU. Consider using tiling and disabling the"
                    " regrow_conncomps and single_tile_reoptimize options"
                )
                raise ValueError(msg)
        if skip_tiling:
            continue

        # Check per-tile dimensions in case a tile exceeds the max array size.
        tile_height = _calc_tile_shape(arr.shape[0], ntiles[0], y_overlap)
        tile_width = _calc_tile_shape(arr.shape[1], ntiles[1], x_overlap)
        tile_shapes_max = (tile_height, tile_width)
        if any(n > LARGESHORT for n in tile_shapes_max):
            msg = (
                f"tile dimensions for {name} dataset are {tile_shapes_max}, which"
                " exceed the max supported by SNAPHU. Consider increasing number"
                " of tiles"
            )
            raise ValueError(msg)


def check_complex_dtype(**datasets: InputDataset | OutputDataset) -> None:
"""
Ensure that one or more datasets is complex-valued.
Expand Down
10 changes: 10 additions & 0 deletions src/snaphu/_unwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
check_complex_dtype,
check_cost_mode,
check_dataset_shapes,
check_dataset_sizes,
check_float_dtype,
check_integer_dtype,
)
Expand Down Expand Up @@ -332,6 +333,15 @@ def unwrap( # type: ignore[no-untyped-def]
ntiles=ntiles, tile_overlap=tile_overlap, nproc=nproc
)

# Ensure that the dataset (and tile) dimensions are not too large.
check_dataset_sizes(
ntiles,
tile_overlap,
regrow_conncomps=regrow_conncomps,
single_tile_reoptimize=single_tile_reoptimize,
igram=igram,
)

with scratch_directory(scratchdir, delete=delete_scratch) as dir_:
# Create a raw binary file in the scratch directory for the interferogram and
# copy the input data to it. (`mkstemp` is used to avoid data races in case the
Expand Down
48 changes: 48 additions & 0 deletions test/test_unwrap.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from unittest.mock import MagicMock

import numpy as np
import pytest
Expand Down Expand Up @@ -183,6 +184,53 @@ def test_shape_mismatch(self):
with pytest.raises(ValueError, match=pattern):
snaphu.unwrap(igram, corr, nlooks=100.0)

@pytest.mark.parametrize(
    (
        "arr_shape",
        "kwargs",
        "err_msg",
    ),
    [
        ( # conncomps not on, but still too large
            (32001, 128),
            {"regrow_conncomps": False},
            r"igram dataset with shape \(32001, 128\) exceeds max dimensions",
        ),
        ( # no regrowing, but tile 1 pixel too large
            (63937, 128),
            {
                "ntiles": (2, 2),
                "tile_overlap": (64, 64),
                "regrow_conncomps": False,
                "single_tile_reoptimize": False,
            },
            r"tile dimensions for igram dataset are \(32001, 96\)",
        ),
        ( # regrow on, single_tile too large due to overlap
            (128, 128),
            {
                "ntiles": (2, 2),
                "tile_overlap": (63873, 64),
                "single_tile_reoptimize": True,
            },
            r"tile dimensions for igram dataset are \(32001, 96\)",
        ),
    ],
)
def test_shape_too_large(
    self, arr_shape: tuple[int, int], kwargs: dict[str, Any], err_msg: str
):
    """Check that oversized datasets/tiles are rejected with a clear error."""

    def fake_2d_array(dtype: type) -> MagicMock:
        # Mock a 2-D array with the desired shape & dtype without allocating
        # tens of gigabytes of actual data.
        mock = MagicMock(spec=np.ndarray)
        mock.shape = arr_shape
        mock.ndim = 2
        mock.dtype = dtype
        return mock

    igram = fake_2d_array(np.complex64)
    corr = fake_2d_array(np.float32)
    with pytest.raises(ValueError, match=err_msg):
        snaphu.unwrap(igram, corr, nlooks=1.0, **kwargs)

def test_bad_igram_dtype(self):
shape = (128, 128)
igram = np.empty(shape, dtype=np.float64)
Expand Down