Skip to content

Commit 429a3c9

Browse files
Virginia Brancatovbrancatgmgunter
authored andcommitted
WIP: Pybind11 for Phass unwrapper (#723)
* Pybind11 interface for Phass unwrapper * Unit test for Phass unwrapper * Correct identation of attributes description * isort unit test inputs and removing dset=None * Explicit dataset extensions in unit test and removing unit test dataset cleanup * Checking difference absolute value instead of difference * Using clang-format to format Phass pybind interface * Correcting comment misalignment * Fixup style - Convert tabs to spaces - Remove trailing whitespace - Use consistent spacing around "=" symbols Co-authored-by: vbrancat <[email protected]> Co-authored-by: Geoffrey M Gunter <[email protected]>
1 parent 75786fa commit 429a3c9

File tree

6 files changed

+242
-0
lines changed

6 files changed

+242
-0
lines changed

python/extensions/pybind_isce3/Sources.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ product/RadarGridParameters.cpp
6363
product/Swath.cpp
6464
unwrap/unwrap.cpp
6565
unwrap/ICU.cpp
66+
unwrap/Phass.cpp
6667
isce.cpp
6768
)
6869

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#include "Phass.h"
2+
#include <isce3/io/Raster.h>
3+
4+
namespace py = pybind11;
5+
6+
using isce3::io::Raster;
7+
using isce3::unwrap::phass::Phass;
8+
9+
void addbinding(py::class_<Phass> & pyPhass)
10+
{
11+
pyPhass.doc() = R"(
12+
class for initializing Phass unwrapping algorithm
13+
14+
Attributes
15+
----------
16+
correlation_threshold : float
17+
Correlation threshold increment
18+
good_correlation : float
19+
Good correlation threshold
20+
min_pixels_region : int
21+
Minimum size of a region to be unwrapped
22+
)";
23+
24+
pyPhass
25+
// Constructor
26+
.def(py::init([](const double correlation_threshold,
27+
const double good_correlation,
28+
const int min_pixels_region)
29+
{
30+
Phass phass;
31+
phass.correlationThreshold(correlation_threshold);
32+
phass.goodCorrelation(good_correlation);
33+
phass.minPixelsPerRegion(min_pixels_region);
34+
35+
return phass;
36+
}),
37+
py::arg("correlation_threshold") = 0.2,
38+
py::arg("good_correlation") = 0.7,
39+
py::arg("min_pixels_region") = 200
40+
)
41+
.def("unwrap", py::overload_cast<Raster&, Raster&, Raster&, Raster&>(&Phass::unwrap),
42+
py::arg("phase"),
43+
py::arg("correlation"),
44+
py::arg("unw_igram"),
45+
py::arg("label"),
46+
R"(
47+
Perform phase unwrapping using the Phass algorithm
48+
49+
Parameters
50+
----------
51+
phase: Raster
52+
Input interferometric phase (radians)
53+
correlation: Raster
54+
Input interferometric correlation
55+
unw_igram: Raster
56+
Output unwrapped interferogram
57+
label: Raster
58+
Output connected components
59+
)")
60+
.def("unwrap", py::overload_cast<Raster&, Raster&, Raster&, Raster&, Raster&>(&Phass::unwrap),
61+
py::arg("phase"),
62+
py::arg("power"),
63+
py::arg("correlation"),
64+
py::arg("unw_igram"),
65+
py::arg("label"),
66+
R"(
67+
Perform phase unwrapping using the Phass algorithm
68+
69+
Parameters
70+
----------
71+
phase: Raster
72+
Input interferometric phase (radians)
73+
power: Raster
74+
Power of reference RSLC
75+
correlation: Raster
76+
Input interferometric correlation
77+
unw_igram: Raster
78+
Output unwrapped interferogram
79+
label: Raster
80+
Output connected components
81+
)")
82+
83+
// Properties
84+
.def_property("correlation_threshold",
85+
py::overload_cast<>(&Phass::correlationThreshold, py::const_),
86+
py::overload_cast<double>(&Phass::correlationThreshold))
87+
.def_property("good_correlation",
88+
py::overload_cast<>(&Phass::goodCorrelation, py::const_),
89+
py::overload_cast<double>(&Phass::goodCorrelation))
90+
.def_property("min_pixels_region",
91+
py::overload_cast<>(&Phass::minPixelsPerRegion, py::const_),
92+
py::overload_cast<int>(&Phass::minPixelsPerRegion))
93+
;
94+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
3+
#include <isce3/unwrap/phass/Phass.h>
4+
#include <pybind11/pybind11.h>
5+
6+
void addbinding(pybind11::class_<isce3::unwrap::phass::Phass> &);

python/extensions/pybind_isce3/unwrap/unwrap.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "unwrap.h"
22
#include "ICU.h"
3+
#include "Phass.h"
34

45

56
namespace py = pybind11;
@@ -10,8 +11,10 @@ void addsubmodule_unwrap(py::module & m)
1011

1112
// forward declare bound classes
1213
py::class_<isce3::unwrap::icu::ICU> pyICU(m_unwrap, "ICU");
14+
py::class_<isce3::unwrap::phass::Phass> pyPhass(m_unwrap, "Phass");
1315

1416
// add bindings
1517
addbinding(pyICU);
18+
addbinding(pyPhass);
1619

1720
}

tests/python/extensions/pybind/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ signal/filter2D.py
4141
product/radargridparameters.py
4242
product/swath.py
4343
unwrap/icu.py
44+
unwrap/phass.py
4445
geometry/ltpcoordinates.py
4546
geometry/pntintersect.py
4647
antenna/frame.py
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
'''
2+
Unit test for Phass unwrapper
3+
'''
4+
5+
import os
6+
7+
import numpy as np
8+
import numpy.testing as npt
9+
import pybind_isce3 as isce3
10+
from osgeo import gdal, gdal_array
11+
12+
width = 256
13+
length = 1100
14+
15+
16+
def to_gdal_dataset(outpath, array):
17+
driver = gdal.GetDriverByName("Gtiff")
18+
dtype = gdal_array.NumericTypeCodeToGDALTypeCode(array.dtype)
19+
length, width = array.shape
20+
dset = driver.Create(outpath, xsize=width, ysize=length, bands=1,
21+
eType=dtype)
22+
dset.GetRasterBand(1).WriteArray(array)
23+
24+
25+
def create_datasets():
26+
# Generate interferogram
27+
xx = np.linspace(0.0, 50.0, width)
28+
yy = np.linspace(0.0, 50.0, length)
29+
30+
x, y = np.meshgrid(xx, yy)
31+
igram = np.exp(1j * (x + y))
32+
phase = np.angle(igram)
33+
34+
to_gdal_dataset('phase.tif', phase)
35+
36+
# Generate coherence
37+
corr = np.zeros((length, width), dtype=np.float32)
38+
corr[100:900, 50:100] = 1.0
39+
corr[100:900, 150:200] = 1.0
40+
corr[900:950, 50:200] = 1.0
41+
corr[1000:1050, 50:200] = 1.0
42+
43+
to_gdal_dataset('coherence.tif', corr)
44+
45+
46+
def read_raster(infile):
47+
ds = gdal.Open(infile, gdal.GA_ReadOnly)
48+
array = ds.GetRasterBand(1).ReadAsArray()
49+
ds = None
50+
return array
51+
52+
53+
def test_getter_setter():
54+
phass = isce3.unwrap.Phass()
55+
56+
phass.correlation_threshold = 0.5
57+
npt.assert_equal(phass.correlation_threshold, 0.5)
58+
59+
phass.good_correlation = 0.6
60+
npt.assert_equal(phass.good_correlation, 0.6)
61+
62+
phass.min_pixels_region = 100
63+
npt.assert_equal(phass.min_pixels_region, 100)
64+
65+
66+
def test_run_phass():
67+
# Create interferogram and coherence
68+
create_datasets()
69+
70+
# Open created datasets as ISCE3 rasters
71+
phase = isce3.io.Raster('phase.tif')
72+
corr = isce3.io.Raster('coherence.tif')
73+
74+
# Generate output rasters
75+
unwRaster = isce3.io.Raster('unw.f4', phase.width,
76+
phase.length, 1, gdal.GDT_Float32, "ENVI")
77+
labelRaster = isce3.io.Raster('label.u1', phase.width,
78+
phase.length, 1, gdal.GDT_Byte, "ENVI")
79+
80+
# Configure and run Phass
81+
phass = isce3.unwrap.Phass()
82+
phass.unwrap(phase, corr, unwRaster, labelRaster)
83+
84+
85+
def test_check_unwrapped_phase():
86+
# Read interferogram and connected components
87+
label = read_raster('label.u1')
88+
unw = read_raster('unw.f4')
89+
90+
# Generate reference interferogram
91+
xx = np.linspace(0.0, 50.0, width)
92+
yy = np.linspace(0.0, 50.0, length)
93+
94+
x, y = np.meshgrid(xx, yy)
95+
ref_unw = x + y
96+
97+
# Reference to each label differently
98+
labels = np.unique(label)
99+
diff = (ref_unw[np.where(label == labels[1])] - ref_unw[102, 52]) - \
100+
(unw[np.where(label == labels[1])] - unw[102, 52])
101+
npt.assert_array_less(np.abs(diff).max(), 1e-5)
102+
103+
diff = (ref_unw[np.where(label == labels[2])] - ref_unw[1002, 52]) - \
104+
(unw[np.where(label == labels[2])] - unw[1002, 52])
105+
npt.assert_array_less(np.abs(diff).max(), 1e-5)
106+
107+
108+
def test_check_labels():
109+
# Open labels
110+
label = read_raster('label.u1')
111+
l, w = label.shape
112+
113+
npt.assert_equal(w, width)
114+
npt.assert_equal(l, length)
115+
116+
# Check all pixels within the U
117+
# have the same label
118+
119+
npt.assert_equal(np.all(label[100:900, 50:100] == label[100, 50]), True)
120+
npt.assert_equal(np.all(label[100:900, 150:200] == label[100, 50]), True)
121+
npt.assert_equal(np.all(label[900:950, 50:200] == label[900, 50]), True)
122+
npt.assert_equal(np.all(label[1000:1050, 50:200] == label[1000, 50]), True)
123+
124+
# Check different connected components
125+
# have different labels
126+
npt.assert_raises(AssertionError, npt.assert_array_equal, label[100, 50],
127+
label[1000, 50])
128+
npt.assert_raises(AssertionError, npt.assert_array_equal, label[900, 50],
129+
label[1000, 50])
130+
131+
132+
if __name__ == '__main__':
133+
test_getter_setter()
134+
test_run_phass()
135+
test_check_unwrapped_phase()
136+
test_check_labels()
137+

0 commit comments

Comments
 (0)