Skip to content

Commit f4b20e4

Browse files
Virginia Brancatovbrancat
authored andcommitted
Pybind interface for filter2D (#714)
* Add bindings for filter2D * add filter2D bindings to Sources.make * Add filter2D bindings in signal and correct typos * Introduce default number of block_rows * Add filter2D unit test * Correcting typo in pybind test * Change pybind implementation to lambda function * Modify unit test for complex data * Reformat pybind code to abide pep8 * Format unit test to abide pep8 notation * Add pybind args and descriptive string * Introduce kernel and input/output dimension checks * Add IndexError as exception in filter2D unit test * Specifying band number * Insert RunTimeError exception as default option in switch block * Switching the include filter2D statement in filter2D.cpp * Simplify unit test using scipy.signal.convolve2D, assertion statements, add magnitude difference check for complex data * Adding const to kernel_rows kernel_columns at C++ level * Adding const to kernel variable in pybind11 interface * Removing unused variables and inputs * Correcting comment on data type * Address input/output different size in pybind11 interface Co-authored-by: vbrancat <[email protected]>
1 parent aa7b396 commit f4b20e4

File tree

8 files changed

+210
-20
lines changed

8 files changed

+210
-20
lines changed

cxx/isce3/signal/filter2D.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#include <isce3/signal/convolve.h>
1010
#include <isce3/signal/decimate.h>
1111

12-
void check_kernels(std::valarray<double>& kernel_columns,
13-
std::valarray<double>& kernel_rows)
12+
void check_kernels(const std::valarray<double>& kernel_columns,
13+
const std::valarray<double>& kernel_rows)
1414
{
1515

1616
size_t ncols_kernel = kernel_columns.size();
@@ -128,8 +128,8 @@ void setup_block_parameters(const int nrows, const int blockRows,
128128
template<typename T>
129129
void isce3::signal::filter2D(isce3::io::Raster& output_raster,
130130
isce3::io::Raster& input_raster,
131-
std::valarray<double>& kernel_columns,
132-
std::valarray<double>& kernel_rows, int block_rows)
131+
const std::valarray<double>& kernel_columns,
132+
const std::valarray<double>& kernel_rows, int block_rows)
133133
{
134134

135135
bool do_decimate = false;
@@ -164,8 +164,8 @@ template<typename T>
164164
void isce3::signal::filter2D(isce3::io::Raster& output_raster,
165165
isce3::io::Raster& input_raster,
166166
isce3::io::Raster& mask_raster,
167-
std::valarray<double>& kernel_columns,
168-
std::valarray<double>& kernel_rows, int block_rows)
167+
const std::valarray<double>& kernel_columns,
168+
const std::valarray<double>& kernel_rows, int block_rows)
169169
{
170170

171171
std::cout << "A mask is provided. The input will be masked before filtering"
@@ -188,8 +188,8 @@ template<typename T>
188188
void isce3::signal::filter2D(isce3::io::Raster& output_raster,
189189
isce3::io::Raster& input_raster,
190190
isce3::io::Raster& mask_raster,
191-
std::valarray<double>& kernel_columns,
192-
std::valarray<double>& kernel_rows,
191+
const std::valarray<double>& kernel_columns,
192+
const std::valarray<double>& kernel_rows,
193193
const bool do_decimate, const bool mask_data,
194194
int block_rows)
195195
{
@@ -334,18 +334,18 @@ void isce3::signal::filter2D(isce3::io::Raster& output_raster,
334334
template void isce3::signal::filter2D<T>( \
335335
isce3::io::Raster & output_raster, \
336336
isce3::io::Raster & input_raster, \
337-
std::valarray<double> & kernel_columns, \
338-
std::valarray<double> & kernel_rows, int block_rows); \
337+
const std::valarray<double> & kernel_columns, \
338+
const std::valarray<double> & kernel_rows, int block_rows); \
339339
template void isce3::signal::filter2D<T>( \
340340
isce3::io::Raster & output_raster, \
341341
isce3::io::Raster & input_raster, isce3::io::Raster & mask_raster, \
342-
std::valarray<double> & kernel_columns, \
343-
std::valarray<double> & kernel_rows, int block_rows); \
342+
const std::valarray<double> & kernel_columns, \
343+
const std::valarray<double> & kernel_rows, int block_rows); \
344344
template void isce3::signal::filter2D<T>( \
345345
isce3::io::Raster & output_raster, \
346346
isce3::io::Raster & input_raster, isce3::io::Raster & mask_raster, \
347-
std::valarray<double> & kernel_columns, \
348-
std::valarray<double> & kernel_rows, const bool do_decimate, \
347+
const std::valarray<double> & kernel_columns, \
348+
const std::valarray<double> & kernel_rows, const bool do_decimate, \
349349
const bool mask, int block_rows)
350350

351351
SPECIALIZE_FILTER(float);

cxx/isce3/signal/filter2D.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ namespace isce3 { namespace signal {
1616
*/
1717
template<typename T>
1818
void filter2D(isce3::io::Raster& output_raster, isce3::io::Raster& input_raster,
19-
std::valarray<double>& kernel_columns,
20-
std::valarray<double>& kernel_rows, int block_rows = 1000);
19+
const std::valarray<double>& kernel_columns,
20+
const std::valarray<double>& kernel_rows, int block_rows = 1000);
2121

2222
/**
2323
* filters real or complex data by convolving two 1D separable kernels in
@@ -32,8 +32,8 @@ void filter2D(isce3::io::Raster& output_raster, isce3::io::Raster& input_raster,
3232
template<typename T>
3333
void filter2D(isce3::io::Raster& output_raster, isce3::io::Raster& input_raster,
3434
isce3::io::Raster& mask_raster,
35-
std::valarray<double>& kernel_columns,
36-
std::valarray<double>& kernel_rows, int block_rows = 1000);
35+
const std::valarray<double>& kernel_columns,
36+
const std::valarray<double>& kernel_rows, int block_rows = 1000);
3737

3838
/**
3939
* filters real or complex data by convolving two 1D separable kernels in
@@ -50,8 +50,8 @@ void filter2D(isce3::io::Raster& output_raster, isce3::io::Raster& input_raster,
5050
template<typename T>
5151
void filter2D(isce3::io::Raster& output_raster, isce3::io::Raster& input_raster,
5252
isce3::io::Raster& mask_raster,
53-
std::valarray<double>& kernel_columns,
54-
std::valarray<double>& kernel_rows, const bool do_decimate,
53+
const std::valarray<double>& kernel_columns,
54+
const std::valarray<double>& kernel_rows, const bool do_decimate,
5555
const bool mask = true, int block_rows = 1000);
5656

5757
}} // namespace isce3::signal

python/extensions/pybind_isce3/Sources.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ signal/Covariance.cpp
5656
signal/Crossmul.cpp
5757
signal/CrossMultiply.cpp
5858
signal/flatten.cpp
59+
signal/filter2D.cpp
5960
product/GeoGridParameters.cpp
6061
product/product.cpp
6162
product/RadarGridParameters.cpp
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "filter2D.h"
2+
3+
#include <gdal_priv.h>
4+
#include <valarray>
5+
#include <pybind11/stl.h>
6+
#include <pybind11/complex.h>
7+
8+
#include <isce3/signal/filter2D.h>
9+
#include <isce3/except/Error.h>
10+
#include <isce3/io/Raster.h>
11+
12+
namespace py = pybind11;
13+
14+
using isce3::signal::filter2D;
15+
using isce3::io::Raster;
16+
17+
void addbinding_filter2D(py::module& m)
18+
{
19+
m
20+
.def("filter2D", [](
21+
Raster& output,
22+
Raster& input,
23+
const std::valarray<double> & kernel_columns,
24+
const std::valarray<double> & kernel_rows,
25+
int block_rows) {
26+
int band=1;
27+
auto in_type = input.dtype(band);
28+
auto out_type = output.dtype(band);
29+
// ensure input and output data type match
30+
if (in_type != out_type)
31+
throw isce3::except::RuntimeError(ISCE_SRCINFO(),
32+
"Input and output data type do not match");
33+
// check kernel_columns dimensions
34+
if (kernel_columns.size() > input.width())
35+
throw isce3::except::RuntimeError(ISCE_SRCINFO(),
36+
"kernel column width > input width");
37+
// check kernel_rows dimensions
38+
if (kernel_rows.size() > input.length())
39+
throw isce3::except::RuntimeError(ISCE_SRCINFO(),
40+
"kernel rows width > input length");
41+
if (output.length() != input.length())
42+
if (output.length() != input.length() / kernel_rows.size())
43+
throw isce3::except::RuntimeError(ISCE_SRCINFO(),
44+
"output length not equal to input length or not"
45+
"equal to input length divided by kernel row size");
46+
if (output.width() != input.width())
47+
if (output.width() != output.width() / kernel_columns.size())
48+
throw isce3::except::RuntimeError(ISCE_SRCINFO(),
49+
"output width not equal to input width or not"
50+
"equal to input width divided by kernel columns size");
51+
// ensure types match
52+
switch (in_type) {
53+
case GDT_Float32: filter2D<float>(output, input, kernel_columns,
54+
kernel_rows, block_rows);
55+
return;
56+
case GDT_Float64: filter2D<double>(output, input, kernel_columns,
57+
kernel_rows, block_rows);
58+
return;
59+
case GDT_CFloat32: filter2D<std::complex<float>>(output, input,
60+
kernel_columns, kernel_rows, block_rows);
61+
return;
62+
case GDT_CFloat64: filter2D<std::complex<double>>(output, input,
63+
kernel_columns, kernel_rows, block_rows);
64+
return;
65+
default: throw isce3::except::RuntimeError(ISCE_SRCINFO(),
66+
"unsupported GDAL datatype");
67+
}
68+
69+
},
70+
py::arg("output"),
71+
py::arg("input"),
72+
py::arg("kernel_columns"),
73+
py::arg("kernel_rows"),
74+
py::arg("block_rows"),
75+
"Filter real or complex data by convolving two separable 1D kernels")
76+
;
77+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
void addbinding_filter2D(pybind11::module& m);

python/extensions/pybind_isce3/signal/signal.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "CrossMultiply.h"
66
#include "Crossmul.h"
77
#include "flatten.h"
8+
#include "filter2D.h"
89

910
namespace py = pybind11;
1011

@@ -25,6 +26,7 @@ void addsubmodule_signal(py::module & m)
2526
addbinding(pyCrossmul);
2627
addbinding(pyCrossMultiply);
2728
addbinding_flatten(m_signal);
29+
addbinding_filter2D(m_signal);
2830
addbinding_convolve2D<float>(m_signal);
2931
addbinding_convolve2D<std::complex<float>>(m_signal);
3032
addbinding_convolve2D<double>(m_signal);

tests/python/extensions/pybind/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ signal/convolve2D.py
3737
signal/covariance.py
3838
signal/crossmul.py
3939
signal/crossmultiply.py
40+
signal/filter2D.py
4041
product/radargridparameters.py
4142
product/swath.py
4243
unwrap/icu.py
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import pybind_isce3 as isce3
2+
import numpy as np
3+
from scipy.signal import convolve2d
4+
5+
from osgeo import gdal
6+
from osgeo import gdal_array
7+
8+
9+
def to_gdal_dataset(outpath, array):
10+
driver = gdal.GetDriverByName("GTiff")
11+
dtype = gdal_array.NumericTypeCodeToGDALTypeCode(array.dtype)
12+
length, width = array.shape
13+
dset = driver.Create(outpath, xsize=width, ysize=length, bands=1,
14+
eType=dtype)
15+
dset.GetRasterBand(1).WriteArray(array)
16+
17+
18+
def open_raster(filepath):
19+
ds = gdal.Open(filepath, gdal.GA_ReadOnly)
20+
array = ds.GetRasterBand(1).ReadAsArray()
21+
22+
return array
23+
24+
25+
def test_run_filter2D():
26+
# Data dimension
27+
length = 200
28+
width = 311
29+
30+
# Kernel dimensions
31+
kernel_length = 3
32+
33+
block = 20
34+
35+
# Create filter kernels
36+
kernel1d = np.ones([kernel_length, 1], dtype=np.float64) / kernel_length
37+
38+
# Create real data to filter
39+
data_real = np.zeros([length, width], dtype=np.float64)
40+
data_cpx = np.zeros([length, width], dtype=np.complex128)
41+
42+
for line in range(0, length):
43+
for col in range(0, width):
44+
data_real[line, col] = line + col
45+
data_cpx[line, col] = np.cos(line * col) + 1.0j * np.sin(line * col)
46+
47+
# Save data
48+
to_gdal_dataset('data.real', data_real)
49+
to_gdal_dataset('data.cpx', data_cpx)
50+
51+
# Filter data
52+
filt_data_real = isce3.io.Raster('data_real.filt', width, length, 1,
53+
gdal.GDT_Float64, "ENVI")
54+
filt_data_cpx = isce3.io.Raster('data_cpx.filt', width, length, 1,
55+
gdal.GDT_CFloat64, "ENVI")
56+
57+
data_raster_real = isce3.io.Raster('data.real')
58+
data_raster_cpx = isce3.io.Raster('data.cpx')
59+
60+
isce3.signal.filter2D(filt_data_real, data_raster_real, kernel1d,
61+
kernel1d, block)
62+
isce3.signal.filter2D(filt_data_cpx, data_raster_cpx, kernel1d,
63+
kernel1d, block)
64+
65+
66+
def test_validate_filter_real():
67+
# Create 2D kernel
68+
kernel_length = 3
69+
kernel_width = 3
70+
kernel = np.ones((kernel_length, kernel_width), dtype=np.float64) / (
71+
kernel_width * kernel_length)
72+
73+
# Validate filter2D for real data
74+
data = open_raster('data.real')
75+
76+
filt_data = open_raster('data_real.filt')
77+
78+
# Regenerate filtered data
79+
out = convolve2d(data, kernel, mode='same')
80+
81+
diff = np.abs(out - filt_data)
82+
assert diff.max() < 1e-12
83+
84+
85+
def test_validate_filter_complex():
86+
# Create kernel 2 D
87+
kernel_width = 3
88+
kernel_length = 3
89+
kernel = np.ones((kernel_length, kernel_width), dtype=np.float64) / (
90+
kernel_width * kernel_length)
91+
92+
# Open data
93+
data = open_raster('data.cpx')
94+
95+
filt_data = open_raster('data_cpx.filt')
96+
97+
# Regenerate filter data
98+
out = convolve2d(data, kernel, mode='same')
99+
100+
diff_pha = np.angle(filt_data * np.conj(out))
101+
diff_amp = np.abs(filt_data) - np.abs(out)
102+
103+
assert diff_pha.max() < 1e-12
104+
assert diff_amp.max() < 1e-12

0 commit comments

Comments
 (0)