Skip to content

Commit 49784de

Browse files
Added unwrap operator (#1133)
* Added unwrap operator Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 829ec0d commit 49784de

File tree

6 files changed

+343
-0
lines changed

6 files changed

+343
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
.. _unwrap_func:
2+
3+
unwrap
4+
======
5+
6+
Unwrap phase-like values by replacing jumps larger than a discontinuity with their period-complementary values.
7+
This function is not optimized for parallel performance: each output element sequentially rescans all preceding elements along the unwrap axis.
8+
9+
This matches NumPy's ``unwrap`` characteristics:
10+
11+
- Operates along a selected axis (default: last axis)
12+
- Uses a configurable ``period`` (default: ``2*pi``)
13+
- Treats ``discont < period/2`` as ``period/2``
14+
15+
.. doxygenfunction:: unwrap
16+
17+
Examples
18+
~~~~~~~~
19+
20+
.. literalinclude:: ../../../../test/00_operators/unwrap_test.cu
21+
:language: cpp
22+
:start-after: example-begin unwrap-test-1
23+
:end-before: example-end unwrap-test-1
24+
:dedent:
25+

include/matx/operators/operators.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
#include "matx/operators/trace.h"
118118
#include "matx/operators/transpose.h"
119119
#include "matx/operators/unique.h"
120+
#include "matx/operators/unwrap.h"
120121
#include "matx/operators/updownsample.h"
121122
#include "matx/operators/var.h"
122123
#include "matx/operators/zipvec.h"

include/matx/operators/unwrap.h

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// BSD 3-Clause License
3+
//
4+
// Copyright (c) 2026, NVIDIA Corporation
5+
// All rights reserved.
6+
//
7+
// Redistribution and use in source and binary forms, with or without
8+
// modification, are permitted provided that the following conditions are met:
9+
//
10+
// 1. Redistributions of source code must retain the above copyright notice, this
11+
// list of conditions and the following disclaimer.
12+
//
13+
// 2. Redistributions in binary form must reproduce the above copyright notice,
14+
// this list of conditions and the following disclaimer in the documentation
15+
// and/or other materials provided with the distribution.
16+
//
17+
// 3. Neither the name of the copyright holder nor the names of its
18+
// contributors may be used to endorse or promote products derived from
19+
// this software without specific prior written permission.
20+
//
21+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
/////////////////////////////////////////////////////////////////////////////////
32+
33+
#pragma once
34+
35+
#include "matx/core/type_utils.h"
36+
#include "matx/operators/base_operator.h"
37+
#include "matx/operators/scalar_internal.h"
38+
39+
namespace matx {
namespace detail {

// NumPy-compatible phase-unwrap operator.
//
// For every output element at position `out_idx` along `axis_`, the value is
// the input element plus the cumulative phase correction accumulated over ALL
// preceding elements along that axis (indices 1..out_idx).  This sequential
// rescan makes each output O(axis length), so the operator trades parallel
// efficiency for simplicity — the docs call this out explicitly.
//
// MathType is the floating-point type the arithmetic is done in (enforced by
// static_assert below); complex inputs are rejected.
template <typename OpA, typename MathType>
class UnwrapOp : public BaseOp<UnwrapOp<OpA, MathType>> {
  public:
    using matxop = bool;
    using value_type = typename OpA::value_type;

    // op      : upstream operator to unwrap
    // axis    : axis to unwrap along; negative values count from the end
    // discont : maximum allowed jump between adjacent samples
    // period  : full period of the wrapped values (e.g. 2*pi for radians)
    __MATX_INLINE__ UnwrapOp(const OpA &op, int axis, MathType discont, MathType period)
      : op_(op), discont_(discont), period_(period), half_period_(period / static_cast<MathType>(2)) {
      static_assert(cuda::std::is_floating_point_v<MathType>,
                    "unwrap() requires a floating-point input");
      static_assert(!is_complex_v<value_type>,
                    "unwrap() does not support complex input");

      MATX_ASSERT_STR(period_ > static_cast<MathType>(0), matxInvalidParameter,
                      "unwrap period must be positive");

      MATX_LOOP_UNROLL
      for (int i = 0; i < Rank(); i++) {
        sizes_[i] = op_.Size(i);
      }

      // Normalize a negative axis to its positive equivalent; a rank-0
      // operator has no axis, so pin it to 0.
      if constexpr (Rank() > 0) {
        axis_ = axis;
        if (axis_ < 0) {
          axis_ += Rank();
        }
        MATX_ASSERT_STR(axis_ >= 0 && axis_ < Rank(), matxInvalidDim,
                        "unwrap axis must be in range [-rank, rank-1]");
      }
      else {
        axis_ = 0;
      }

      // Match NumPy semantics: discont values smaller than period/2 are treated
      // as period/2.
      if (discont_ < half_period_) {
        discont_ = half_period_;
      }
    }

    __MATX_INLINE__ std::string str() const { return "unwrap(" + op_.str() + ")"; }

    // Element accessor.  For rank-0 inputs unwrapping is a no-op.  Otherwise
    // the cumulative correction for this element's position along axis_ is
    // recomputed from index 0 each call (no cross-thread state is shared).
    template <typename CapType, typename... Is>
    __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const {
      if constexpr (Rank() == 0) {
        return get_value<CapType>(op_);
      }
      else {
        constexpr index_t EPT = static_cast<index_t>(CapType::ept);
        // Extract lane `lane` of a possibly-vectorized value as MathType.
        // With EPT == ONE the value is scalar and the lane index is ignored.
        auto get_lane_scalar = [](const auto &v, index_t lane) {
          if constexpr (CapType::ept == ElementsPerThread::ONE) {
            (void)lane;
            return static_cast<MathType>(v);
          }
          else {
            return static_cast<MathType>(v.data[lane]);
          }
        };

        cuda::std::array<index_t, Rank()> idx{indices...};
        const index_t out_idx = idx[axis_];
        const auto cur = get_value<CapType>(op_, idx);
        // Per-lane running phase correction; value-initialized to zero.
        // NOTE(review): get_capability() below restricts this operator to
        // ElementsPerThread::ONE, so the multi-lane path looks defensive —
        // confirm lanes can never span axis_ if that restriction is lifted.
        cuda::std::array<MathType, static_cast<size_t>(EPT)> correction{};

        if (out_idx != 0) {
          cuda::std::array<index_t, Rank()> seq_idx = idx;
          seq_idx[axis_] = 0;
          auto prev = get_value<CapType>(op_, seq_idx);
          const MathType neg_half_period = -half_period_;

          // Walk the axis from the start up to this element, accumulating the
          // correction exactly the way NumPy's unwrap does.
          for (index_t i = 1; i <= out_idx; i++) {
            seq_idx[axis_] = i;
            const auto next = get_value<CapType>(op_, seq_idx);

            MATX_LOOP_UNROLL
            for (index_t lane = 0; lane < EPT; lane++) {
              const MathType next_s = get_lane_scalar(next, lane);
              const MathType prev_s = get_lane_scalar(prev, lane);
              const MathType delta = next_s - prev_s;

              // Map delta into [-period/2, period/2); fmod can return a
              // negative remainder, so fold it back into range first.
              MathType delta_mod =
                  static_cast<MathType>(scalar_internal_fmod(delta + half_period_, period_));
              if (delta_mod < static_cast<MathType>(0)) {
                delta_mod += period_;
              }
              delta_mod -= half_period_;

              // Boundary tie-break (NumPy-compatible): a jump of exactly
              // +period/2 must map to +period/2, not -period/2.
              if (delta_mod == neg_half_period && delta > static_cast<MathType>(0)) {
                delta_mod = half_period_;
              }

              // Jumps smaller than the discontinuity threshold are left alone.
              MathType phase_correction = delta_mod - delta;
              if (cuda::std::abs(delta) < discont_) {
                phase_correction = static_cast<MathType>(0);
              }

              correction[static_cast<size_t>(lane)] += phase_correction;
            }
            prev = next;
          }
        }

        if constexpr (CapType::ept == ElementsPerThread::ONE) {
          return static_cast<value_type>(get_lane_scalar(cur, 0) + correction[0]);
        }
        else {
          Vector<value_type, EPT> out{};
          MATX_LOOP_UNROLL
          for (index_t lane = 0; lane < EPT; lane++) {
            out.data[lane] = static_cast<value_type>(
                get_lane_scalar(cur, lane) + correction[static_cast<size_t>(lane)]);
          }
          return out;
        }
      }
    }

    // Non-capability overload; forwards with the default capability set.
    template <typename... Is>
    __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const {
      return this->operator()<DefaultCapabilities>(indices...);
    }

    // Rank and per-dimension sizes mirror the wrapped operator.
    static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() {
      return OpA::Rank();
    }

    constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const {
      return sizes_[dim];
    }

    // Forward pre/post-run hooks to the wrapped operator when it is a MatX op.
    template <typename ShapeType, typename Executor>
    __MATX_INLINE__ void PreRun(ShapeType &&shape, Executor &&ex) const noexcept {
      if constexpr (is_matx_op<OpA>()) {
        op_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
      }
    }

    template <typename ShapeType, typename Executor>
    __MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) const noexcept {
      if constexpr (is_matx_op<OpA>()) {
        op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
      }
    }

    // Capability query: force one element per thread (the sequential axis scan
    // in operator() indexes per element); all other capabilities combine the
    // default with the wrapped operator's answer.
    template <OperatorCapability Cap, typename InType>
    __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType &in) const {
      if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) {
        const auto my_cap =
            cuda::std::array<ElementsPerThread, 2>{ElementsPerThread::ONE, ElementsPerThread::ONE};
        return combine_capabilities<Cap>(my_cap, detail::get_operator_capability<Cap>(op_, in));
      }
      else {
        auto self_has_cap = capability_attributes<Cap>::default_value;
        return combine_capabilities<Cap>(self_has_cap, detail::get_operator_capability<Cap>(op_, in));
      }
    }

  private:
    typename detail::base_type_t<OpA> op_;       // wrapped operator
    cuda::std::array<index_t, Rank()> sizes_;    // cached sizes of op_
    int axis_;                                   // normalized (non-negative) unwrap axis
    MathType discont_;                           // jump threshold, clamped to >= period/2
    MathType period_;                            // full wrap period (> 0)
    MathType half_period_;                       // period_/2, precomputed
};
} // namespace detail
207+
208+
/**
 * @brief Unwrap phase angles by correcting jumps greater than a discontinuity.
 *
 * Mirrors NumPy's `unwrap` behavior, including support for a custom `period`
 * and `discont`. A negative `discont` (the default sentinel of -1) selects the
 * NumPy default of `period / 2`.
 *
 * @tparam Op Input operator/tensor type
 * @param op Input operator
 * @param axis Axis to unwrap. Default is last axis (-1)
 * @param discont Maximum discontinuity between adjacent samples. Values lower
 *                than `period / 2` are treated as `period / 2`.
 * @param period Complement period used to unwrap phase values. Default is 2*pi.
 */
template <typename Op>
__MATX_INLINE__ auto unwrap(
    const Op &op, int axis = -1,
    detail::value_promote_t<typename Op::value_type> discont =
        static_cast<detail::value_promote_t<typename Op::value_type>>(-1),
    detail::value_promote_t<typename Op::value_type> period =
        static_cast<detail::value_promote_t<typename Op::value_type>>(
            cuda::std::numbers::pi_v<detail::value_promote_t<typename Op::value_type>> * 2)) {
  MATX_NVTX_START("unwrap(" + get_type_str(op) + ")", matx::MATX_NVTX_LOG_API)
  using math_type = detail::value_promote_t<typename Op::value_type>;

  const math_type wrap_period = static_cast<math_type>(period);
  math_type jump_threshold = static_cast<math_type>(discont);

  // The -1 sentinel (or any negative value) means "use NumPy's default",
  // i.e. half of the wrap period.
  if (jump_threshold < static_cast<math_type>(0)) {
    jump_threshold = wrap_period / static_cast<math_type>(2);
  }

  return detail::UnwrapOp<Op, math_type>(op, axis, jump_threshold, wrap_period);
}
239+
} // namespace matx

test/00_operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ set(OPERATOR_TEST_FILES
6767
toeplitz_test.cu
6868
transpose_test.cu
6969
trig_funcs_test.cu
70+
unwrap_test.cu
7071
updownsample_test.cu
7172
zipvec_test.cu
7273
)

test/00_operators/unwrap_test.cu

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "operator_test_types.hpp"
2+
#include "matx.h"
3+
#include "test_types.h"
4+
#include "utilities.h"
5+
6+
using namespace matx;
7+
using namespace matx::test;
8+
9+
// Verifies unwrap() against NumPy reference vectors generated by the
// `unwrap_operator` class in test/test_vectors/generators/00_operators.py.
// Covers: default parameters (1D), custom discont/period (1D), axis selection
// (2D), and the discont < period/2 clamp (2D).
TYPED_TEST(OperatorTestsFloatNonComplexNonHalfAllExecs, Unwrap)
{
  MATX_ENTER_HANDLER();
  using TestType = cuda::std::tuple_element_t<0, TypeParam>;
  using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

  // Sizes {37, 11, 17} become n (1D length) and m x k (2D shape) in the
  // Python generator.
  auto pb = std::make_unique<detail::MatXPybind>();
  pb->InitAndRunTVGenerator<TestType>("00_operators", "unwrap_operator", "run", {37, 11, 17});

  ExecType exec{};
  auto in1 = make_tensor<TestType>({37});
  auto in2 = make_tensor<TestType>({11, 17});
  auto out1_default = make_tensor<TestType>({37});
  auto out1_period = make_tensor<TestType>({37});
  auto out2_axis1 = make_tensor<TestType>({11, 17});
  auto out2_axis0 = make_tensor<TestType>({11, 17});

  pb->NumpyToTensorView(in1, "in1");
  pb->NumpyToTensorView(in2, "in2");

  // Default parameters: last axis, discont = period/2, period = 2*pi.
  // example-begin unwrap-test-1
  (out1_default = unwrap(in1)).run(exec);
  // example-end unwrap-test-1
  exec.sync();
  MATX_TEST_ASSERT_COMPARE(pb, out1_default, "out1_default", 0.01);

  // Custom discont and period on the default (last) axis.
  (out1_period = unwrap(in1, -1, static_cast<TestType>(2.5), static_cast<TestType>(4.0))).run(exec);
  exec.sync();
  MATX_TEST_ASSERT_COMPARE(pb, out1_period, "out1_period", 0.01);

  // Explicit axis selection on a 2D input.
  (out2_axis1 = unwrap(in2, 1)).run(exec);
  exec.sync();
  MATX_TEST_ASSERT_COMPARE(pb, out2_axis1, "out2_axis1", 0.01);

  // discont < period/2 should behave like discont == period/2.
  (out2_axis0 = unwrap(in2, 0, static_cast<TestType>(1.0), static_cast<TestType>(6.0))).run(exec);
  exec.sync();
  MATX_TEST_ASSERT_COMPARE(pb, out2_axis0, "out2_axis0", 0.01);

  MATX_EXIT_HANDLER();
}

test/test_vectors/generators/00_operators.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,34 @@ def run(self) -> Dict[str, np.array]:
153153
'out3': sl.toeplitz(c, r2)
154154
}
155155

156+
class unwrap_operator:
    """Generates NumPy reference vectors for the MatX unwrap() operator tests.

    run() produces a wrapped 1-D signal and a wrapped 2-D signal, plus
    np.unwrap references for the default parameters, custom discont/period,
    and both 2-D axes.
    """

    def __init__(self, dtype: str, size: List[int]):
        # size = [n, m, k]: 1-D length and the 2-D (m, k) shape used by run().
        self.size = size
        self.dtype = dtype
        # Fixed seed so generated vectors are reproducible across runs.
        np.random.seed(1234)

    def run(self) -> Dict[str, np.array]:
        n, m, k = self.size[0], self.size[1], self.size[2]
        dtype = np.dtype(self.dtype)

        # Noisy 1-D phase ramp, folded into (-pi, pi] via angle(exp(1j*phase)).
        ramp = np.linspace(0.0, 7.0 * np.pi, n) + 0.2 * np.random.randn(n)
        in1 = np.angle(np.exp(1j * ramp)).astype(dtype)

        # Noisy 2-D phase surface, wrapped the same way.
        surface = np.linspace(0.0, 9.0 * np.pi, m * k).reshape((m, k))
        surface = surface + 0.25 * np.random.randn(m, k)
        in2 = np.angle(np.exp(1j * surface)).astype(dtype)

        return {
            'in1': in1,
            'in2': in2,
            'out1_default': np.unwrap(in1).astype(dtype),
            'out1_period': np.unwrap(in1, discont=2.5, period=4.0).astype(dtype),
            'out2_axis1': np.unwrap(in2, axis=1).astype(dtype),
            'out2_axis0': np.unwrap(in2, axis=0, discont=1.0, period=6.0).astype(dtype)
        }
156184
class pwelch_operators:
157185
def __init__(self, dtype: str, cfg: Dict): #PWelchGeneratorCfg):
158186
self.dtype = dtype

0 commit comments

Comments
 (0)