Skip to content

Commit 298f8b1

Browse files
authored
feat: NxSignal.FindPeaks (#13)
* feat: start peak finding algos * feat: add argrel* functions * test: run doctests
1 parent 33dfcbd commit 298f8b1

File tree

3 files changed

+398
-1
lines changed

3 files changed

+398
-1
lines changed

lib/nx_signal/peak_finding.ex

Lines changed: 392 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,392 @@
1+
defmodule NxSignal.PeakFinding do
2+
@moduledoc """
3+
Peak finding algorithms.
4+
"""
5+
6+
import Nx.Defn
7+
import Nx, only: [u8: 1, s64: 1]
8+
9+
@doc """
10+
Finds a relative minimum along the selected `:axis`.
11+
12+
A relative minimum is defined by the element being greater
13+
than its neighbors along the axis `:axis`.
14+
15+
Returns a map in the following format:
16+
17+
%{
18+
indices: #Nx.Tensor<...>,
19+
valid_indices: #Nx.Tensor<...>
20+
}
21+
22+
* `:indices` - the `{n, rank}` tensor of indices.
23+
Contains `-1` as a placeholder for invalid indices.
24+
25+
* `:valid_indices` - the number of valid indices that lead the tensor.
26+
27+
## Options
28+
29+
* `:axis` - the axis along which to do comparisons. Defaults to 0.
30+
* `:order` - the number of neighbor samples considered for the
31+
comparison in each direction. Defaults to 1.
32+
33+
## Examples
34+
35+
iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0])
36+
iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x)
37+
iex> valid_indices
38+
#Nx.Tensor<
39+
u64
40+
2
41+
>
42+
iex> indices
43+
#Nx.Tensor<
44+
s64[8][1]
45+
[
46+
[1],
47+
[5],
48+
[-1],
49+
[-1],
50+
[-1],
51+
[-1],
52+
[-1],
53+
[-1]
54+
]
55+
>
56+
iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0)
57+
#Nx.Tensor<
58+
s64[2][1]
59+
[
60+
[1],
61+
[5]
62+
]
63+
>
64+
65+
For the same tensor in the previous example, we can use `:order` to check if
66+
the relative maxima are extrema in a wider neighborhood.
67+
68+
iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0])
69+
iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x, order: 3)
70+
iex> valid_indices
71+
#Nx.Tensor<
72+
u64
73+
1
74+
>
75+
iex> indices
76+
#Nx.Tensor<
77+
s64[8][1]
78+
[
79+
[1],
80+
[-1],
81+
[-1],
82+
[-1],
83+
[-1],
84+
[-1],
85+
[-1],
86+
[-1]
87+
]
88+
>
89+
iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0)
90+
#Nx.Tensor<
91+
s64[1][1]
92+
[
93+
[1]
94+
]
95+
>
96+
97+
We can also apply this function to tensors with a larger rank:
98+
99+
iex> x = Nx.tensor([[1, 2, 1, 2], [6, 2, 0, 0], [5, 3, 4, 4]])
100+
iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x)
101+
iex> valid_indices
102+
#Nx.Tensor<
103+
u64
104+
2
105+
>
106+
iex> indices[0..1]
107+
#Nx.Tensor<
108+
s64[2][2]
109+
[
110+
[1, 2],
111+
[1, 3]
112+
]
113+
>
114+
iex> %{indices: indices} = NxSignal.PeakFinding.argrelmin(x, axis: 1)
115+
iex> valid_indices
116+
#Nx.Tensor<
117+
u64
118+
2
119+
>
120+
iex> indices[0..1]
121+
#Nx.Tensor<
122+
s64[2][2]
123+
[
124+
[0, 2],
125+
[2, 1]
126+
]
127+
>
128+
129+
"""
130+
@doc type: :peak_finding
131+
defn argrelmin(data, opts \\ []) do
132+
opts = keyword!(opts, axis: 0, order: 1)
133+
argrelextrema(data, &Nx.less/2, opts)
134+
end
135+
136+
@doc """
137+
Finds a relative maximum along the selected `:axis`.
138+
139+
A relative maximum is defined by the element being greater
140+
than its neighbors along the axis `:axis`.
141+
142+
Returns a map in the following format:
143+
144+
%{
145+
indices: #Nx.Tensor<...>,
146+
valid_indices: #Nx.Tensor<...>
147+
}
148+
149+
* `:indices` - the `{n, rank}` tensor of indices.
150+
Contains `-1` as a placeholder for invalid indices.
151+
152+
* `:valid_indices` - the number of valid indices that lead the tensor.
153+
154+
## Options
155+
156+
* `:axis` - the axis along which to do comparisons. Defaults to 0.
157+
* `:order` - the number of neighbor samples considered for the
158+
comparison in each direction. Defaults to 1.
159+
160+
## Examples
161+
162+
iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0])
163+
iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x)
164+
iex> valid_indices
165+
#Nx.Tensor<
166+
u64
167+
2
168+
>
169+
iex> indices
170+
#Nx.Tensor<
171+
s64[8][1]
172+
[
173+
[3],
174+
[6],
175+
[-1],
176+
[-1],
177+
[-1],
178+
[-1],
179+
[-1],
180+
[-1]
181+
]
182+
>
183+
iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0)
184+
#Nx.Tensor<
185+
s64[2][1]
186+
[
187+
[3],
188+
[6]
189+
]
190+
>
191+
192+
For the same tensor in the previous example, we can use `:order` to check if
193+
the relative maxima are extrema in a wider neighborhood.
194+
195+
iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0])
196+
iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x, order: 3)
197+
iex> valid_indices
198+
#Nx.Tensor<
199+
u64
200+
1
201+
>
202+
iex> indices
203+
#Nx.Tensor<
204+
s64[8][1]
205+
[
206+
[3],
207+
[-1],
208+
[-1],
209+
[-1],
210+
[-1],
211+
[-1],
212+
[-1],
213+
[-1]
214+
]
215+
>
216+
iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0)
217+
#Nx.Tensor<
218+
s64[1][1]
219+
[
220+
[3]
221+
]
222+
>
223+
224+
We can also apply this function to tensors with a larger rank:
225+
226+
iex> x = Nx.tensor([[1, 2, 1, 2], [6, 2, 0, 0], [5, 3, 4, 4]])
227+
iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x)
228+
iex> valid_indices
229+
#Nx.Tensor<
230+
u64
231+
1
232+
>
233+
iex> indices[0]
234+
#Nx.Tensor<
235+
s64[2]
236+
[1, 0]
237+
>
238+
iex> %{indices: indices} = NxSignal.PeakFinding.argrelmax(x, axis: 1)
239+
iex> valid_indices
240+
#Nx.Tensor<
241+
u64
242+
1
243+
>
244+
iex> indices[0]
245+
#Nx.Tensor<
246+
s64[2]
247+
[0, 1]
248+
>
249+
250+
"""
251+
@doc type: :peak_finding
252+
defn argrelmax(data, opts \\ []) do
253+
opts = keyword!(opts, axis: 0, order: 1)
254+
argrelextrema(data, &Nx.greater/2, opts)
255+
end
256+
257+
@doc """
258+
Finds a relative extrema along the selected `:axis`.
259+
260+
A relative extremum is defined by the given `comparator_fn`
261+
function of arity 2 function that returns a boolean tensor.
262+
263+
This is the function upon which `&argrelmax/2` and `&argrelmin/2`
264+
are implemented.
265+
266+
Returns a map in the following format:
267+
268+
%{
269+
indices: #Nx.Tensor<...>,
270+
valid_indices: #Nx.Tensor<...>
271+
}
272+
273+
* `:indices` - the `{n, rank}` tensor of indices.
274+
Contains `-1` as a placeholder for invalid indices.
275+
276+
* `:valid_indices` - the number of valid indices that lead the tensor.
277+
278+
## Options
279+
280+
* `:axis` - the axis along which to do comparisons. Defaults to 0.
281+
* `:order` - the number of neighbor samples considered for the
282+
comparison in each direction. Defaults to 1.
283+
284+
## Examples
285+
286+
First, do read the examples on `argrelmax/2` keeping in mind that
287+
it is equivalent to `argrelextrema(&1, &Nx.greater/2, &2)`, as well
288+
as `argrelmin/2` which is equivalent to `argrelextrema(&1, &Nx.less/2, &2)`.
289+
290+
Having that in mind, we will expand on those concepts by using a custom function.
291+
For instance, we can change the definition of a relative maximum to one where
292+
a number is a relative maximum if it is greater than or equal to the double of its
293+
neighbors, as follows:
294+
295+
iex> comparator = fn x, y -> Nx.greater_equal(x, Nx.multiply(y, 2)) end
296+
iex> x = Nx.tensor([0, 1, 3, 2, 0, 1, 0, 0, 0, 2, 1])
297+
iex> result = NxSignal.PeakFinding.argrelextrema(x, comparator)
298+
iex> result.valid_indices
299+
#Nx.Tensor<
300+
u64
301+
3
302+
>
303+
iex> result.indices[0..2]
304+
#Nx.Tensor<
305+
s64[3][1]
306+
[
307+
[5],
308+
[7],
309+
[9]
310+
]
311+
>
312+
313+
Same applies for finding local minima. In the next example, we
314+
find all local minima (i.e. `&Nx.less/2`) that are
315+
different to the global minimum.
316+
317+
iex> x = Nx.tensor([0, 1, 0, 2, 1, 3, 0, 1])
318+
iex> global_minimum = Nx.reduce_min(x)
319+
iex> comparator = fn x, y ->
320+
...> x_not_global = Nx.not_equal(x, global_minimum)
321+
...> y_not_global = Nx.not_equal(y, global_minimum)
322+
...> both_not_global = Nx.logical_and(x_not_global, y_not_global)
323+
...> Nx.logical_and(Nx.less(x, y), both_not_global)
324+
...> end
325+
iex> result = NxSignal.PeakFinding.argrelextrema(x, comparator)
326+
iex> result.valid_indices
327+
#Nx.Tensor<
328+
u64
329+
1
330+
>
331+
iex> result.indices[0..0]
332+
#Nx.Tensor<
333+
s64[1][1]
334+
[
335+
[4]
336+
]
337+
>
338+
"""
339+
@doc type: :peak_finding
340+
defn argrelextrema(data, comparator_fn, opts \\ []) do
341+
opts = keyword!(opts, axis: 0, order: 1)
342+
343+
data
344+
|> boolrelextrema(comparator_fn, opts)
345+
|> nonzero()
346+
end
347+
348+
defnp boolrelextrema(data, comparator_fn, opts \\ []) do
349+
axis = opts[:axis]
350+
order = opts[:order]
351+
locs = Nx.iota({Nx.axis_size(data, axis)})
352+
353+
ones = Nx.broadcast(u8(1), data.shape)
354+
[ones, _] = Nx.broadcast_vectors([ones, data])
355+
356+
{results, _} =
357+
while {results = ones, {data, locs, halt = u8(0), shift = s64(1)}},
358+
not halt and shift < order + 1 do
359+
plus = Nx.take(data, Nx.clip(locs + shift, 0, Nx.size(locs) - 1), axis: axis)
360+
minus = Nx.take(data, Nx.clip(locs - shift, 0, Nx.size(locs) - 1), axis: axis)
361+
results = comparator_fn.(data, plus) and results
362+
results = comparator_fn.(data, minus) and results
363+
364+
{results, {data, locs, not Nx.any(results), shift + 1}}
365+
end
366+
367+
results
368+
end
369+
370+
deftransformp nonzero(data) do
371+
flat_data = Nx.reshape(data, {:auto, 1})
372+
373+
indices =
374+
for axis <- 0..(Nx.rank(data) - 1),
375+
reduce: Nx.broadcast(0, {Nx.axis_size(flat_data, 0), Nx.rank(data)}) do
376+
%{shape: {n, _}} = indices ->
377+
iota = data.shape |> Nx.iota(axis: axis) |> Nx.reshape({n, 1})
378+
Nx.put_slice(indices, [0, axis], iota)
379+
end
380+
381+
indices_with_mask =
382+
Nx.select(
383+
Nx.broadcast(flat_data, indices.shape),
384+
indices,
385+
Nx.broadcast(-1, indices.shape)
386+
)
387+
388+
order = Nx.argsort(Nx.squeeze(flat_data, axes: [1]), axis: 0, direction: :desc)
389+
390+
%{indices: Nx.take(indices_with_mask, order), valid_indices: Nx.sum(flat_data)}
391+
end
392+
end

test/nx_signal/peak_finding_test.exs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
defmodule NxSignal.PeakFindingTest do
2+
use NxSignal.Case, async: true
3+
doctest NxSignal.PeakFinding
4+
end

0 commit comments

Comments
 (0)