Skip to content

Commit 3bbe87e

Browse files
authored
Add hanning window function (ml-explore#3124)
1 parent e226af7 commit 3bbe87e

File tree

4 files changed

+50
-0
lines changed

4 files changed

+50
-0
lines changed

mlx/ops.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2311,6 +2311,19 @@ array argmax(
23112311
return out;
23122312
}
23132313

2314+
array hanning(int M, StreamOrDevice s /* = {} */) {
2315+
if (M < 1) {
2316+
return array({});
2317+
}
2318+
if (M == 1) {
2319+
return ones({1}, float32, s);
2320+
}
2321+
2322+
auto n = arange(0, M, float32, s);
2323+
array factor(M_PI / (M - 1), float32);
2324+
return square(sin(multiply(factor, n, s), s), s);
2325+
}
2326+
23142327
/** Returns a sorted copy of the flattened array. */
23152328
array sort(const array& a, StreamOrDevice s /* = {} */) {
23162329
int size = a.size();

mlx/ops.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,9 @@ min(const array& a,
666666
MLX_API array
667667
min(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});
668668

669+
/** Returns the Hanning window of size M. */
670+
MLX_API array hanning(int M, StreamOrDevice s = {});
671+
669672
/** Returns the index of the minimum value in the array. */
670673
MLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
671674
inline array argmin(const array& a, StreamOrDevice s = {}) {

python/src/ops.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,28 @@ void init_ops(nb::module_& m) {
14281428
"stream"_a = nb::none(),
14291429
nb::sig(
14301430
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
1431+
m.def(
1432+
"hanning",
1433+
&mlx::core::hanning,
1434+
"M"_a,
1435+
nb::kw_only(),
1436+
"stream"_a = nb::none(),
1437+
R"pbdoc(
1438+
Return the Hanning window.
1439+
1440+
The Hanning window is a taper formed by using a weighted cosine.
1441+
1442+
.. math::
1443+
w(n) = 0.5 - 0.5 \cos\left(\frac{2\pi n}{M-1}\right)
1444+
\qquad 0 \le n \le M-1
1445+
1446+
Args:
1447+
M (int): Number of points in the output window.
1448+
1449+
Returns:
1450+
array: The window, with the maximum value normalized to one (the value one
1451+
appears only if the number of samples is odd).
1452+
)pbdoc");
14311453
m.def(
14321454
"linspace",
14331455
[](Scalar start,

python/tests/test_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,18 @@ def test_arange_corner_cases_cast(self):
14501450
expected = [0]
14511451
self.assertListEqual(a.tolist(), expected)
14521452

1453+
def test_hanning_general(self):
1454+
a = mx.hanning(10)
1455+
expected = np.hanning(10)
1456+
self.assertTrue(np.allclose(a, expected, atol=1e-5))
1457+
1458+
a = mx.hanning(1)
1459+
self.assertEqual(a.item(), 1.0)
1460+
1461+
a = mx.hanning(0)
1462+
self.assertEqual(a.size, 0)
1463+
self.assertEqual(a.dtype, mx.float32)
1464+
14531465
def test_unary_ops(self):
14541466
def test_ops(npop, mlxop, x, y, atol, rtol):
14551467
r_np = npop(x)

0 commit comments

Comments
 (0)