Skip to content

Commit 4f1ee15

Browse files
Optimise as_windowed and apply :reflect padding to the whole input (#17)
1 parent 298f8b1 commit 4f1ee15

File tree

1 file changed

+30
-107
lines changed

1 file changed

+30
-107
lines changed

lib/nx_signal.ex

Lines changed: 30 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,10 @@ defmodule NxSignal do
172172
173173
* `:window_length` - the number of samples in a window
174174
* `:stride` - The number of samples to skip between windows. Defaults to `1`.
175-
* `:padding` - A can be `:reflect` or a valid padding as per `Nx.pad/3` over the
176-
input tensor's shape. Defaults to `:valid`. If `:reflect` or `:zeros`, the first window will be centered
177-
at the start of the signal. For `:reflect`, each incomplete window will be reflected as if it was
178-
periodic (see examples for `as_windowed/2`). For `:zeros`, each incomplete window will be zero-padded.
175+
* `:padding` - Padding mode, can be `:reflect` or a valid padding as per `Nx.pad/3` over the
176+
input tensor's shape. Defaults to `:valid`. If `:reflect` or `:same`, the first window will be centered
177+
at the start of the signal. The padding is applied to the whole input, rather than to individual
178+
windows. For `:same`, the input is zero-padded, so effectively each incomplete window will be zero-padded.
179179
180180
## Examples
181181
@@ -219,27 +219,29 @@ defmodule NxSignal do
219219
iex> t = Nx.iota({7});
220220
iex> NxSignal.as_windowed(t, window_length: 6, padding: :reflect, stride: 1)
221221
#Nx.Tensor<
222-
s64[7][6]
222+
s64[8][6]
223223
[
224-
[1, 2, 1, 0, 1, 2],
224+
[3, 2, 1, 0, 1, 2],
225225
[2, 1, 0, 1, 2, 3],
226226
[1, 0, 1, 2, 3, 4],
227227
[0, 1, 2, 3, 4, 5],
228228
[1, 2, 3, 4, 5, 6],
229229
[2, 3, 4, 5, 6, 5],
230-
[3, 4, 5, 6, 5, 4]
230+
[3, 4, 5, 6, 5, 4],
231+
[4, 5, 6, 5, 4, 3]
231232
]
232233
>
233234
234235
iex> NxSignal.as_windowed(Nx.iota({10}), window_length: 6, padding: :reflect, stride: 2)
235236
#Nx.Tensor<
236-
s64[5][6]
237+
s64[6][6]
237238
[
238-
[1, 2, 1, 0, 1, 2],
239+
[3, 2, 1, 0, 1, 2],
239240
[1, 0, 1, 2, 3, 4],
240241
[1, 2, 3, 4, 5, 6],
241242
[3, 4, 5, 6, 7, 8],
242-
[5, 6, 7, 8, 9, 8]
243+
[5, 6, 7, 8, 9, 8],
244+
[7, 8, 9, 8, 7, 6]
243245
]
244246
>
245247
"""
@@ -257,7 +259,7 @@ defmodule NxSignal do
257259

258260
as_windowed_parse_non_reflect_opts(
259261
shape,
260-
Keyword.put(opts, :padding, [{div(window_length, 2), div(window_length, 2) - 1}])
262+
Keyword.put(opts, :padding, [{div(window_length, 2), div(window_length, 2)}])
261263
)
262264
end
263265

@@ -333,114 +335,34 @@ defmodule NxSignal do
333335
{window_length, stride, padding, output_shape} =
334336
as_windowed_parse_non_reflect_opts(Nx.shape(tensor), opts)
335337

336-
output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
337-
{num_windows, _} = Nx.shape(output)
338-
339-
index_template =
340-
Nx.concatenate([Nx.broadcast(0, {window_length, 1}), Nx.iota({window_length, 1})], axis: 1)
341-
342-
{output, _, _, _, _} =
343-
while {output, i = 0, current_window = 0, t = Nx.pad(tensor, 0, padding), index_template},
344-
current_window < num_windows do
345-
indices = index_template + Nx.stack([current_window, 0])
346-
updates = t |> Nx.slice([i], [window_length]) |> Nx.flatten()
347-
348-
updated = Nx.indexed_add(output, indices, updates)
338+
tensor = Nx.pad(tensor, 0, padding)
349339

350-
{updated, i + stride, current_window + 1, t, index_template}
351-
end
352-
353-
output
340+
as_windowed_apply(tensor, stride, output_shape, window_length)
354341
end
355342

356343
defnp as_windowed_reflect_padding(tensor, opts \\ []) do
357344
# current implementation only supports windowing 1D tensors
358345
{window_length, stride, _padding, output_shape} =
359346
as_windowed_parse_reflect_opts(Nx.shape(tensor), opts)
360347

361-
output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
362-
{num_windows, _} = Nx.shape(output)
363-
364-
index_template =
365-
Nx.concatenate([Nx.broadcast(0, {window_length, 1}), Nx.iota({window_length, 1})], axis: 1)
366-
367-
leading_window_indices = generate_leading_window_indices(window_length, stride)
368-
369-
trailing_window_indices =
370-
generate_trailing_window_indices(Nx.size(tensor), window_length, stride)
371-
372-
half_window = div(window_length - 1, 2) + 1
373-
374-
{output, _, _, _, _} =
375-
while {output, i = 0, current_window = 0, t = tensor, index_template},
376-
current_window < num_windows do
377-
# Here windows are centered at the current index
378-
379-
cond do
380-
i < half_window ->
381-
# We're indexing before we have a full window on the left
382-
383-
window = Nx.take(t, leading_window_indices[i])
384-
385-
indices = index_template + Nx.stack([current_window, 0])
386-
updated = Nx.indexed_add(output, indices, window)
387-
388-
{updated, i + stride, current_window + 1, t, index_template}
389-
390-
i > Nx.size(t) - half_window ->
391-
# We're indexing after the last full window on the right
392-
window = Nx.take(t, trailing_window_indices[i - (Nx.size(t) - half_window + 1)])
393-
394-
indices = index_template + Nx.stack([current_window, 0])
395-
updated = Nx.indexed_add(output, indices, window)
396-
397-
{updated, i + stride, current_window + 1, t, index_template}
398-
399-
true ->
400-
# Case where we can index a full window
401-
indices = index_template + Nx.stack([current_window, 0])
402-
updates = t |> Nx.slice([i - half_window], [window_length]) |> Nx.flatten()
403-
404-
updated = Nx.indexed_add(output, indices, updates)
405-
406-
{updated, i + stride, current_window + 1, t, index_template}
407-
end
408-
end
409-
410-
# Now we need to handle the tail-end of the windows,
411-
# since they are currently all the same value. We want to apply the tapering-off
412-
# like we did with the initial windows.
413-
414-
output
415-
end
416-
417-
deftransformp generate_leading_window_indices(window_length, stride) do
418348
half_window = div(window_length, 2)
349+
tensor = Nx.reflect(tensor, padding_config: [{half_window, half_window}])
419350

420-
for offset <- 0..half_window//stride do
421-
partial_length = offset + half_window
422-
padding_length = window_length - partial_length
423-
424-
{partial_length}
425-
|> Nx.iota()
426-
|> Nx.reflect(padding_config: [{padding_length, 0}])
427-
end
428-
|> Nx.stack()
351+
as_windowed_apply(tensor, stride, output_shape, window_length)
429352
end
430353

431-
deftransformp generate_trailing_window_indices(tensor_size, window_length, stride) do
432-
min_index = tensor_size - window_length + 1
354+
defnp as_windowed_apply(tensor, stride, output_shape, window_length) do
355+
output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
356+
{num_windows, _} = Nx.shape(output)
433357

434-
for {offset, add} <- Enum.with_index(min_index..(tensor_size - 1)//stride) do
435-
partial_length = tensor_size - offset
436-
padding_length = window_length - partial_length
358+
{output, _, _, _} =
359+
while {output, i = 0, current_window = 0, t = tensor}, current_window < num_windows do
360+
window = t |> Nx.slice([i], [window_length])
361+
updated = Nx.put_slice(output, [current_window, 0], Nx.new_axis(window, 0))
362+
{updated, i + stride, current_window + 1, t}
363+
end
437364

438-
{partial_length}
439-
|> Nx.iota()
440-
|> Nx.add(min_index + add - rem(window_length, 2))
441-
|> Nx.reflect(padding_config: [{0, padding_length}])
442-
end
443-
|> Nx.stack()
365+
output
444366
end
445367

446368
@doc """
@@ -548,15 +470,16 @@ defmodule NxSignal do
548470
iex> Nx.axis_size(z, :frequencies)
549471
16
550472
iex> Nx.axis_size(z, :frames)
551-
5
473+
6
552474
iex> NxSignal.stft_to_mel(z, sampling_rate, fft_length: fft_length, mel_bins: 4)
553475
#Nx.Tensor<
554-
f32[frames: 5][mel: 4]
476+
f32[frames: 6][mel: 4]
555477
[
556478
[0.2900530695915222, 0.17422175407409668, 0.18422472476959229, 0.09807997941970825],
557479
[0.6093881130218506, 0.5647397041320801, 0.4353824257850647, 0.08635270595550537],
558480
[0.7584103345870972, 0.7085014581680298, 0.5636920928955078, 0.179118812084198],
559481
[0.8461772203445435, 0.7952491044998169, 0.6470762491226196, 0.2520409822463989],
482+
[0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721],
560483
[0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721]
561484
]
562485
>

0 commit comments

Comments
 (0)