|
| 1 | +defmodule NxSignal.PeakFinding do |
| 2 | + @moduledoc """ |
| 3 | + Peak finding algorithms. |
| 4 | + """ |
| 5 | + |
| 6 | + import Nx.Defn |
| 7 | + import Nx, only: [u8: 1, s64: 1] |
| 8 | + |
| 9 | + @doc """ |
| 10 | + Finds a relative minimum along the selected `:axis`. |
| 11 | +
|
| 12 | + A relative minimum is defined by the element being greater |
| 13 | + than its neighbors along the axis `:axis`. |
| 14 | +
|
| 15 | + Returns a map in the following format: |
| 16 | +
|
| 17 | + %{ |
| 18 | + indices: #Nx.Tensor<...>, |
| 19 | + valid_indices: #Nx.Tensor<...> |
| 20 | + } |
| 21 | +
|
| 22 | + * `:indices` - the `{n, rank}` tensor of indices. |
| 23 | + Contains `-1` as a placeholder for invalid indices. |
| 24 | +
|
| 25 | + * `:valid_indices` - the number of valid indices that lead the tensor. |
| 26 | +
|
| 27 | + ## Options |
| 28 | +
|
| 29 | + * `:axis` - the axis along which to do comparisons. Defaults to 0. |
| 30 | + * `:order` - the number of neighbor samples considered for the |
| 31 | + comparison in each direction. Defaults to 1. |
| 32 | +
|
| 33 | + ## Examples |
| 34 | +
|
| 35 | + iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0]) |
| 36 | + iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x) |
| 37 | + iex> valid_indices |
| 38 | + #Nx.Tensor< |
| 39 | + u64 |
| 40 | + 2 |
| 41 | + > |
| 42 | + iex> indices |
| 43 | + #Nx.Tensor< |
| 44 | + s64[8][1] |
| 45 | + [ |
| 46 | + [1], |
| 47 | + [5], |
| 48 | + [-1], |
| 49 | + [-1], |
| 50 | + [-1], |
| 51 | + [-1], |
| 52 | + [-1], |
| 53 | + [-1] |
| 54 | + ] |
| 55 | + > |
| 56 | + iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0) |
| 57 | + #Nx.Tensor< |
| 58 | + s64[2][1] |
| 59 | + [ |
| 60 | + [1], |
| 61 | + [5] |
| 62 | + ] |
| 63 | + > |
| 64 | +
|
| 65 | + For the same tensor in the previous example, we can use `:order` to check if |
| 66 | + the relative maxima are extrema in a wider neighborhood. |
| 67 | +
|
| 68 | + iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0]) |
| 69 | + iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x, order: 3) |
| 70 | + iex> valid_indices |
| 71 | + #Nx.Tensor< |
| 72 | + u64 |
| 73 | + 1 |
| 74 | + > |
| 75 | + iex> indices |
| 76 | + #Nx.Tensor< |
| 77 | + s64[8][1] |
| 78 | + [ |
| 79 | + [1], |
| 80 | + [-1], |
| 81 | + [-1], |
| 82 | + [-1], |
| 83 | + [-1], |
| 84 | + [-1], |
| 85 | + [-1], |
| 86 | + [-1] |
| 87 | + ] |
| 88 | + > |
| 89 | + iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0) |
| 90 | + #Nx.Tensor< |
| 91 | + s64[1][1] |
| 92 | + [ |
| 93 | + [1] |
| 94 | + ] |
| 95 | + > |
| 96 | +
|
| 97 | + We can also apply this function to tensors with a larger rank: |
| 98 | +
|
| 99 | + iex> x = Nx.tensor([[1, 2, 1, 2], [6, 2, 0, 0], [5, 3, 4, 4]]) |
| 100 | + iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x) |
| 101 | + iex> valid_indices |
| 102 | + #Nx.Tensor< |
| 103 | + u64 |
| 104 | + 2 |
| 105 | + > |
| 106 | + iex> indices[0..1] |
| 107 | + #Nx.Tensor< |
| 108 | + s64[2][2] |
| 109 | + [ |
| 110 | + [1, 2], |
| 111 | + [1, 3] |
| 112 | + ] |
| 113 | + > |
| 114 | + iex> %{indices: indices} = NxSignal.PeakFinding.argrelmin(x, axis: 1) |
| 115 | + iex> valid_indices |
| 116 | + #Nx.Tensor< |
| 117 | + u64 |
| 118 | + 2 |
| 119 | + > |
| 120 | + iex> indices[0..1] |
| 121 | + #Nx.Tensor< |
| 122 | + s64[2][2] |
| 123 | + [ |
| 124 | + [0, 2], |
| 125 | + [2, 1] |
| 126 | + ] |
| 127 | + > |
| 128 | +
|
| 129 | + """ |
| 130 | + @doc type: :peak_finding |
| 131 | + defn argrelmin(data, opts \\ []) do |
| 132 | + opts = keyword!(opts, axis: 0, order: 1) |
| 133 | + argrelextrema(data, &Nx.less/2, opts) |
| 134 | + end |
| 135 | + |
| 136 | + @doc """ |
| 137 | + Finds a relative maximum along the selected `:axis`. |
| 138 | +
|
| 139 | + A relative maximum is defined by the element being greater |
| 140 | + than its neighbors along the axis `:axis`. |
| 141 | +
|
| 142 | + Returns a map in the following format: |
| 143 | +
|
| 144 | + %{ |
| 145 | + indices: #Nx.Tensor<...>, |
| 146 | + valid_indices: #Nx.Tensor<...> |
| 147 | + } |
| 148 | +
|
| 149 | + * `:indices` - the `{n, rank}` tensor of indices. |
| 150 | + Contains `-1` as a placeholder for invalid indices. |
| 151 | +
|
| 152 | + * `:valid_indices` - the number of valid indices that lead the tensor. |
| 153 | +
|
| 154 | + ## Options |
| 155 | +
|
| 156 | + * `:axis` - the axis along which to do comparisons. Defaults to 0. |
| 157 | + * `:order` - the number of neighbor samples considered for the |
| 158 | + comparison in each direction. Defaults to 1. |
| 159 | +
|
| 160 | + ## Examples |
| 161 | +
|
| 162 | + iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0]) |
| 163 | + iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x) |
| 164 | + iex> valid_indices |
| 165 | + #Nx.Tensor< |
| 166 | + u64 |
| 167 | + 2 |
| 168 | + > |
| 169 | + iex> indices |
| 170 | + #Nx.Tensor< |
| 171 | + s64[8][1] |
| 172 | + [ |
| 173 | + [3], |
| 174 | + [6], |
| 175 | + [-1], |
| 176 | + [-1], |
| 177 | + [-1], |
| 178 | + [-1], |
| 179 | + [-1], |
| 180 | + [-1] |
| 181 | + ] |
| 182 | + > |
| 183 | + iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0) |
| 184 | + #Nx.Tensor< |
| 185 | + s64[2][1] |
| 186 | + [ |
| 187 | + [3], |
| 188 | + [6] |
| 189 | + ] |
| 190 | + > |
| 191 | +
|
| 192 | + For the same tensor in the previous example, we can use `:order` to check if |
| 193 | + the relative maxima are extrema in a wider neighborhood. |
| 194 | +
|
| 195 | + iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0]) |
| 196 | + iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x, order: 3) |
| 197 | + iex> valid_indices |
| 198 | + #Nx.Tensor< |
| 199 | + u64 |
| 200 | + 1 |
| 201 | + > |
| 202 | + iex> indices |
| 203 | + #Nx.Tensor< |
| 204 | + s64[8][1] |
| 205 | + [ |
| 206 | + [3], |
| 207 | + [-1], |
| 208 | + [-1], |
| 209 | + [-1], |
| 210 | + [-1], |
| 211 | + [-1], |
| 212 | + [-1], |
| 213 | + [-1] |
| 214 | + ] |
| 215 | + > |
| 216 | + iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0) |
| 217 | + #Nx.Tensor< |
| 218 | + s64[1][1] |
| 219 | + [ |
| 220 | + [3] |
| 221 | + ] |
| 222 | + > |
| 223 | +
|
| 224 | + We can also apply this function to tensors with a larger rank: |
| 225 | +
|
| 226 | + iex> x = Nx.tensor([[1, 2, 1, 2], [6, 2, 0, 0], [5, 3, 4, 4]]) |
| 227 | + iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x) |
| 228 | + iex> valid_indices |
| 229 | + #Nx.Tensor< |
| 230 | + u64 |
| 231 | + 1 |
| 232 | + > |
| 233 | + iex> indices[0] |
| 234 | + #Nx.Tensor< |
| 235 | + s64[2] |
| 236 | + [1, 0] |
| 237 | + > |
| 238 | + iex> %{indices: indices} = NxSignal.PeakFinding.argrelmax(x, axis: 1) |
| 239 | + iex> valid_indices |
| 240 | + #Nx.Tensor< |
| 241 | + u64 |
| 242 | + 1 |
| 243 | + > |
| 244 | + iex> indices[0] |
| 245 | + #Nx.Tensor< |
| 246 | + s64[2] |
| 247 | + [0, 1] |
| 248 | + > |
| 249 | +
|
| 250 | + """ |
| 251 | + @doc type: :peak_finding |
| 252 | + defn argrelmax(data, opts \\ []) do |
| 253 | + opts = keyword!(opts, axis: 0, order: 1) |
| 254 | + argrelextrema(data, &Nx.greater/2, opts) |
| 255 | + end |
| 256 | + |
| 257 | + @doc """ |
| 258 | + Finds a relative extrema along the selected `:axis`. |
| 259 | +
|
| 260 | + A relative extremum is defined by the given `comparator_fn` |
| 261 | + function of arity 2 function that returns a boolean tensor. |
| 262 | +
|
| 263 | + This is the function upon which `&argrelmax/2` and `&argrelmin/2` |
| 264 | + are implemented. |
| 265 | +
|
| 266 | + Returns a map in the following format: |
| 267 | +
|
| 268 | + %{ |
| 269 | + indices: #Nx.Tensor<...>, |
| 270 | + valid_indices: #Nx.Tensor<...> |
| 271 | + } |
| 272 | +
|
| 273 | + * `:indices` - the `{n, rank}` tensor of indices. |
| 274 | + Contains `-1` as a placeholder for invalid indices. |
| 275 | +
|
| 276 | + * `:valid_indices` - the number of valid indices that lead the tensor. |
| 277 | +
|
| 278 | + ## Options |
| 279 | +
|
| 280 | + * `:axis` - the axis along which to do comparisons. Defaults to 0. |
| 281 | + * `:order` - the number of neighbor samples considered for the |
| 282 | + comparison in each direction. Defaults to 1. |
| 283 | +
|
| 284 | + ## Examples |
| 285 | +
|
| 286 | + First, do read the examples on `argrelmax/2` keeping in mind that |
| 287 | + it is equivalent to `argrelextrema(&1, &Nx.greater/2, &2)`, as well |
| 288 | + as `argrelmin/2` which is equivalent to `argrelextrema(&1, &Nx.less/2, &2)`. |
| 289 | +
|
| 290 | + Having that in mind, we will expand on those concepts by using a custom function. |
| 291 | + For instance, we can change the definition of a relative maximum to one where |
| 292 | + a number is a relative maximum if it is greater than or equal to the double of its |
| 293 | + neighbors, as follows: |
| 294 | +
|
| 295 | + iex> comparator = fn x, y -> Nx.greater_equal(x, Nx.multiply(y, 2)) end |
| 296 | + iex> x = Nx.tensor([0, 1, 3, 2, 0, 1, 0, 0, 0, 2, 1]) |
| 297 | + iex> result = NxSignal.PeakFinding.argrelextrema(x, comparator) |
| 298 | + iex> result.valid_indices |
| 299 | + #Nx.Tensor< |
| 300 | + u64 |
| 301 | + 3 |
| 302 | + > |
| 303 | + iex> result.indices[0..2] |
| 304 | + #Nx.Tensor< |
| 305 | + s64[3][1] |
| 306 | + [ |
| 307 | + [5], |
| 308 | + [7], |
| 309 | + [9] |
| 310 | + ] |
| 311 | + > |
| 312 | +
|
| 313 | + Same applies for finding local minima. In the next example, we |
| 314 | + find all local minima (i.e. `&Nx.less/2`) that are |
| 315 | + different to the global minimum. |
| 316 | +
|
| 317 | + iex> x = Nx.tensor([0, 1, 0, 2, 1, 3, 0, 1]) |
| 318 | + iex> global_minimum = Nx.reduce_min(x) |
| 319 | + iex> comparator = fn x, y -> |
| 320 | + ...> x_not_global = Nx.not_equal(x, global_minimum) |
| 321 | + ...> y_not_global = Nx.not_equal(y, global_minimum) |
| 322 | + ...> both_not_global = Nx.logical_and(x_not_global, y_not_global) |
| 323 | + ...> Nx.logical_and(Nx.less(x, y), both_not_global) |
| 324 | + ...> end |
| 325 | + iex> result = NxSignal.PeakFinding.argrelextrema(x, comparator) |
| 326 | + iex> result.valid_indices |
| 327 | + #Nx.Tensor< |
| 328 | + u64 |
| 329 | + 1 |
| 330 | + > |
| 331 | + iex> result.indices[0..0] |
| 332 | + #Nx.Tensor< |
| 333 | + s64[1][1] |
| 334 | + [ |
| 335 | + [4] |
| 336 | + ] |
| 337 | + > |
| 338 | + """ |
| 339 | + @doc type: :peak_finding |
| 340 | + defn argrelextrema(data, comparator_fn, opts \\ []) do |
| 341 | + opts = keyword!(opts, axis: 0, order: 1) |
| 342 | + |
| 343 | + data |
| 344 | + |> boolrelextrema(comparator_fn, opts) |
| 345 | + |> nonzero() |
| 346 | + end |
| 347 | + |
| 348 | + defnp boolrelextrema(data, comparator_fn, opts \\ []) do |
| 349 | + axis = opts[:axis] |
| 350 | + order = opts[:order] |
| 351 | + locs = Nx.iota({Nx.axis_size(data, axis)}) |
| 352 | + |
| 353 | + ones = Nx.broadcast(u8(1), data.shape) |
| 354 | + [ones, _] = Nx.broadcast_vectors([ones, data]) |
| 355 | + |
| 356 | + {results, _} = |
| 357 | + while {results = ones, {data, locs, halt = u8(0), shift = s64(1)}}, |
| 358 | + not halt and shift < order + 1 do |
| 359 | + plus = Nx.take(data, Nx.clip(locs + shift, 0, Nx.size(locs) - 1), axis: axis) |
| 360 | + minus = Nx.take(data, Nx.clip(locs - shift, 0, Nx.size(locs) - 1), axis: axis) |
| 361 | + results = comparator_fn.(data, plus) and results |
| 362 | + results = comparator_fn.(data, minus) and results |
| 363 | + |
| 364 | + {results, {data, locs, not Nx.any(results), shift + 1}} |
| 365 | + end |
| 366 | + |
| 367 | + results |
| 368 | + end |
| 369 | + |
| 370 | + deftransformp nonzero(data) do |
| 371 | + flat_data = Nx.reshape(data, {:auto, 1}) |
| 372 | + |
| 373 | + indices = |
| 374 | + for axis <- 0..(Nx.rank(data) - 1), |
| 375 | + reduce: Nx.broadcast(0, {Nx.axis_size(flat_data, 0), Nx.rank(data)}) do |
| 376 | + %{shape: {n, _}} = indices -> |
| 377 | + iota = data.shape |> Nx.iota(axis: axis) |> Nx.reshape({n, 1}) |
| 378 | + Nx.put_slice(indices, [0, axis], iota) |
| 379 | + end |
| 380 | + |
| 381 | + indices_with_mask = |
| 382 | + Nx.select( |
| 383 | + Nx.broadcast(flat_data, indices.shape), |
| 384 | + indices, |
| 385 | + Nx.broadcast(-1, indices.shape) |
| 386 | + ) |
| 387 | + |
| 388 | + order = Nx.argsort(Nx.squeeze(flat_data, axes: [1]), axis: 0, direction: :desc) |
| 389 | + |
| 390 | + %{indices: Nx.take(indices_with_mask, order), valid_indices: Nx.sum(flat_data)} |
| 391 | + end |
| 392 | +end |
0 commit comments