Skip to content

Commit 583fde5

Browse files
committed
API to convolve arrays/tensors with continuous convolution kernels (#378)
This PR adds a new function ``drjit.convolve()`` that repurposes the ``drjit.resample()`` infrastructure to convolve one or more axes of a Dr.Jit array or tensor with a 1D filter. The user can choose one of multiple presets or specify a custom function.
1 parent 73261b6 commit 583fde5

File tree

6 files changed

+159
-21
lines changed

6 files changed

+159
-21
lines changed

docs/reference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ Rearranging array contents
116116
.. autofunction:: tile
117117
.. autofunction:: repeat
118118
.. autofunction:: resample
119+
.. autofunction:: convolve
119120

120121
Random number generation
121122
------------------------

drjit/__init__.py

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,8 +1830,7 @@ def backward(self):
18301830
def resample(
18311831
source: ArrayT,
18321832
shape: Sequence[int],
1833-
*,
1834-
filter: Union[Literal["box", "linear", "hamming", "cubic", "lanczos"], Callable[[float], float]] = "cubic",
1833+
filter: Union[Literal["box", "linear", "hamming", "cubic", "lanczos", "gaussian"], Callable[[float], float]] = "cubic",
18351834
filter_radius: Optional[float] = None
18361835
) -> ArrayT:
18371836
"""
@@ -1872,16 +1871,21 @@ def resample(
18721871
- ``"hamming"``: uses the same number of input samples as ``"linear"`` but
18731872
better preserves sharpness when downscaling. Do not use for upscaling.
18741873
1875-
- ``"cubic"``: use cubic filter kernel that uses :math:`4^n`
1874+
- ``"cubic"``: use cubic filter kernel that queries :math:`4^n`
18761875
neighbors to reconstruct each output sample when upsampling. Produces
18771876
high-quality results. This is the default.
18781877
1879-
- ``"lanczos"``: use a windowed Lanczos filter that uses :math:`6^n`
1878+
- ``"lanczos"``: use a windowed Lanczos filter that queries :math:`6^n`
18801879
neighbors to reconstruct each output sample when upsampling. This is the
18811880
best filter for smooth signals, but also the costliest. The Lanczos
18821881
filter is susceptible to ringing when the input array contains
18831882
discontinuities.
18841883
1884+
- ``"gaussian"``: use a Gaussian filter that queries :math:`4^n` neighbors
1885+
to reconstruct each output sample when upsampling. The kernel has a
1886+
standard deviation of 0.5 and is truncated after 4 standard deviations.
1887+
This filter is mainly useful when intending to blur a signal.
1888+
18851889
- Besides the above choices, it is also possible to specify a custom filter.
18861890
To do so, use the ``filter`` argument to pass a Python callable with
18871891
signature ``Callable[[float], float]``. In this case, you must also
@@ -1961,6 +1965,95 @@ def resample(
19611965
else:
19621966
return value
19631967

1968+
def convolve(
1969+
source: ArrayT,
1970+
filter: Union[Literal["box", "linear", "hamming", "cubic", "lanczos", "gaussian"], Callable[[float], float]],
1971+
filter_radius: float,
1972+
axis: Union[int, Tuple[int, ...], None] = None
1973+
) -> ArrayT:
1974+
"""
1975+
Convolve one or more axes of an input array/tensor with a 1D filter
1976+
1977+
This function filters one or more axes of a Dr.Jit array or tensor, for
1978+
example to convolve an image with a 2D Gaussian filter to blur spatial
1979+
detail.
1980+
1981+
.. code-block:: python
1982+
1983+
image: TensorXf = ... # a RGB image
1984+
1985+
blurred_image = dr.convolve(
1986+
image,
1987+
filter='gaussian',
1988+
filter_radius=10
1989+
)
1990+
1991+
The filter weights are renormalized to reduce edge effects near the
1992+
boundary of the array.
1993+
1994+
The function supports a set of provided filters, and custom filters
1995+
can also be specified. This works analogously to the :py:func:`resample`
1996+
function, please refer to its documentation for detail.
1997+
1998+
Args:
1999+
source (dr.ArrayBase): The Dr.Jit tensor or 1D array to be convolved.
2000+
2001+
filter (str | Callable[[float], float])
2002+
The desired reconstruction filter, see the above text for an overview.
2003+
Alternatively, a custom reconstruction filter function can also be
2004+
specified.
2005+
2006+
filter_radius (float)
2007+
The radius of the continuous function to be used in the convolution.
2008+
2009+
axis (int | tuple[int, ...] | ... | None): The axis or set of axes
2010+
along which to convolve. The default argument ``axis=None`` causes all
2011+
axes to be convolved. Negative values count from the last dimension.
2012+
2013+
Returns:
2014+
drjit.ArrayBase: The convolved output array. Its type matches ``source``.
2015+
"""
2016+
2017+
shape = source.shape
2018+
strides = _compute_strides(shape)
2019+
ndim = len(shape)
2020+
tp = type(source)
2021+
value = source.array
2022+
2023+
if axis is None:
2024+
axis = tuple(range(ndim))
2025+
elif isinstance(axis, int):
2026+
axis = (axis, )
2027+
2028+
for i in axis:
2029+
if i < 0:
2030+
i = ndim + i
2031+
res = shape[i]
2032+
2033+
# Cache resampler in case it can be reused
2034+
key = (res, res, filter, filter_radius)
2035+
2036+
resampler = _resample_cache.get(key, None)
2037+
if resampler is None:
2038+
resampler = detail.Resampler(
2039+
source_res=res,
2040+
target_res=res,
2041+
filter=filter,
2042+
filter_radius=filter_radius,
2043+
convolve=True
2044+
)
2045+
_resample_cache[key] = resampler
2046+
2047+
value = custom(_ResampleOp,
2048+
resampler=resampler,
2049+
source=value,
2050+
stride=strides[i])
2051+
2052+
if is_tensor_v(tp):
2053+
return tp(value, shape)
2054+
else:
2055+
return value
2056+
19642057

19652058
def _normalize_axis_tuple(t: Union[int, Tuple[int, ...]], ndim: int, name: str) -> List[int]:
19662059
if isinstance(t, int):

include/drjit/resample.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class DRJIT_EXTRA_EXPORT Resampler {
3434
* Create a Resampler that uses a predefined reconstruction filter to
3535
* resample a signal from resolution ``source_res`` to ``target_res``.
3636
*
37-
* The following options are available:
37+
* The following ``filter`` presets are available:
3838
*
3939
* - ``"box"``: use nearest-neighbor interpolation/averaging. This is
4040
* very efficient but generally produces sub-par output that is either
@@ -44,16 +44,29 @@ class DRJIT_EXTRA_EXPORT Resampler {
4444
* reconstruct each output sample when upsampling. Tends to produce
4545
* relatively blurry results.
4646
*
47-
* - ``"cubic"``: use cubic filter kernel that uses 4 neighbors to
47+
* - ``"hamming"``: uses the same number of input samples as ``"linear"``
48+
* but better preserves sharpness when downscaling. Do not use for
49+
* upscaling.
50+
*
51+
* - ``"cubic"``: use cubic filter kernel that queries 4 neighbors to
4852
* reconstruct each output sample when upsampling. Produces high-quality
4953
* results.
5054
*
51-
* - ``"lanczos"``: use a windowed Lanczos filter that uses 6 neighbors to
52-
* reconstruct each output sample when upsampling. This is the best filter
53-
* for smooth signals, but also the costliest. The Lanczos filter is
54-
* susceptible to ringing when the input array contains discontinuities.
55+
* - ``"lanczos"``: use a windowed Lanczos filter that queries 6 neighbors
56+
* to reconstruct each output sample when upsampling. This is the best
57+
* filter for smooth signals, but also the costliest. The Lanczos filter
58+
* is susceptible to ringing when the input array contains discontinuities.
59+
*
60+
* - ``"gaussian"``: use a Gaussian filter that queries 4 neighbors to
61+
* reconstruct each output sample when upsampling. The Gaussian has a
62+
* standard deviation of 0.5 and is truncated after 4 standard
63+
* deviations. This filter is mainly useful when intending to blur a signal.
64+
*
65+
* The optional ``radius_scale`` parameter can be used to scale the
66+
* filter kernel radius (only applied when ``source_res == target_res``).
5567
*/
56-
Resampler(uint32_t source_res, uint32_t target_res, const char *filter);
68+
Resampler(uint32_t source_res, uint32_t target_res, const char *filter,
69+
double radius_scale = 1.0);
5770

5871
/**
5972
* \brief Construct a Resampler using a custom filter kernel.

src/extra/resample.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <drjit/resample.h>
1212
#include <drjit/while_loop.h>
13+
#include <drjit/math.h>
1314
#include <nanothread/nanothread.h>
1415
#include <cmath>
1516
#include <algorithm>
@@ -28,7 +29,7 @@ struct Resampler::Impl {
2829
mutable std::any weights_cache;
2930

3031
Impl(uint32_t source_res, uint32_t target_res, Resampler::Filter filter,
31-
const void *payload, double radius)
32+
const void *payload, double radius, double radius_scale)
3233
: source_res(source_res), target_res(target_res) {
3334
if (source_res == 0 || target_res == 0)
3435
throw std::runtime_error("drjit.Resampler(): source/target resolution cannot be zero!");
@@ -41,7 +42,14 @@ struct Resampler::Impl {
4142
radius *= scale;
4243
}
4344

45+
if (source_res == target_res) {
46+
// Convolution mode, adapt to filter size scale factor
47+
radius *= radius_scale;
48+
filter_scale /= radius_scale;
49+
}
50+
4451
taps = (uint32_t) std::ceil(radius * 2);
52+
4553
offset = unique_ptr<uint32_t[]>(new uint32_t[target_res]);
4654
weights = unique_ptr<double[]>(new double[taps * target_res]);
4755

@@ -121,7 +129,7 @@ static inline double sinc(double x) {
121129
return std::sin(x) / x;
122130
}
123131

124-
Resampler::Resampler(uint32_t source_res, uint32_t target_res, const char *filter) {
132+
Resampler::Resampler(uint32_t source_res, uint32_t target_res, const char *filter, double radius_scale) {
125133
Resampler::Filter filter_cb = nullptr;
126134
double radius = 0.0;
127135

@@ -167,18 +175,29 @@ Resampler::Resampler(uint32_t source_res, uint32_t target_res, const char *filte
167175
return sinc(x) * sinc(x * (1.0 / 3.0));
168176
};
169177
radius = 3.f;
178+
} else if (strcmp(filter, "gaussian") == 0) {
179+
filter_cb = [](double x, const void *) -> double {
180+
if (x < -2.0 || x >= 2.0)
181+
return 0.0;
182+
double stddev = .5,
183+
alpha = -1.0 / (2.0 * square(stddev));
184+
return maximum(0.f, exp(alpha * square(x)) - exp(alpha * square(2.0)));
185+
186+
187+
};
188+
radius = 2.f;
170189
} else {
171190
throw std::runtime_error("'filter': unknown value ('box', 'linear', "
172191
"'hamming', 'cubic', and 'lanczos' are supported).");
173192
}
174193

175-
d = new Impl(source_res, target_res, filter_cb, nullptr, radius);
194+
d = new Impl(source_res, target_res, filter_cb, nullptr, radius, radius_scale);
176195
}
177196

178197
Resampler::Resampler(uint32_t source_res, uint32_t target_res,
179198
Resampler::Filter filter, const void *payload,
180199
double radius)
181-
: d(new Impl(source_res, target_res, filter, payload, radius)) {
200+
: d(new Impl(source_res, target_res, filter, payload, radius, 1.0)) {
182201
}
183202

184203
Resampler::~Resampler() { }

src/python/resample.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <drjit/resample.h>
1212
#include <nanobind/stl/string.h>
13+
#include <nanobind/stl/optional.h>
1314
#include "common.h"
1415

1516
void export_resample(nb::module_ &) {
@@ -18,19 +19,19 @@ void export_resample(nb::module_ &) {
1819

1920
auto resampler = nb::class_<Resampler>(detail, "Resampler")
2021
.def("__init__", [](Resampler *self, uint32_t source_res, uint32_t target_res,
21-
const char *filter, nb::handle filter_radius) {
22-
if (!filter_radius.is_none())
22+
const char *filter, std::optional<double> filter_radius, bool convolve) {
23+
if (filter_radius.has_value() && !convolve)
2324
nb::raise("drjit.Resampler(): 'filter_radius' must be None when using a filter preset.");
24-
new (self) Resampler(source_res, target_res, filter);
25-
}, "source_res"_a, "target_res"_a, "filter"_a, "filter_radius"_a = nb::none())
25+
new (self) Resampler(source_res, target_res, filter, filter_radius.has_value() ? filter_radius.value() : 1.0);
26+
}, "source_res"_a, "target_res"_a, "filter"_a, "filter_radius"_a = nb::none(), "convolve"_a = false)
2627
.def("__init__", [](Resampler *self, uint32_t source_res, uint32_t target_res,
27-
nb::typed<nb::callable, float, float> filter, double filter_radius) {
28+
nb::typed<nb::callable, float, float> filter, double filter_radius, bool) {
2829
Resampler::Filter filter_cb = [](double v, const void *ptr) -> double {
2930
return nb::cast<double>(nb::handle((PyObject *) ptr)(v));
3031
};
3132
new (self) Resampler(source_res, target_res, filter_cb,
3233
filter.ptr(), filter_radius);
33-
}, "source_res"_a, "target_res"_a, "filter"_a, "filter_radius"_a)
34+
}, "source_res"_a, "target_res"_a, "filter"_a, "filter_radius"_a, "convolve"_a = false)
3435
#if defined(DRJIT_ENABLE_CUDA)
3536
.def("resample_fwd",
3637
(dr::CUDAArray<dr::half>(Resampler::*)(const dr::CUDAArray<dr::half> &, uint32_t) const) &Resampler::resample_fwd,

tests/test_resample.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,14 @@ def filt(x):
102102
)
103103

104104
assert dr.allclose(r1, r2)
105+
106+
# Test filtering a signal without changing its resolution
107+
@pytest.test_arrays('float, -jit, shape=(*)')
108+
def test07_convolve(t):
109+
x = t(1, 2, 10, 100)
110+
y = dr.convolve(x, 'linear', 1)
111+
assert dr.allclose(x, y)
112+
113+
y = dr.convolve(x, 'linear', 2)
114+
z = t((1+2*.5)/1.5, (1*.5+2+10*.5)/2, (2*.5+10+100*.5)/2, (100+10*.5)/1.5)
115+
assert dr.allclose(y, z)

0 commit comments

Comments
 (0)