@@ -172,10 +172,10 @@ defmodule NxSignal do
172
172
173
173
* `:window_length` - the number of samples in a window
174
174
* `:stride` - The number of samples to skip between windows. Defaults to `1`.
175
- * `:padding` - A can be `:reflect` or a valid padding as per `Nx.pad/3` over the
176
- input tensor's shape. Defaults to `:valid`. If `:reflect` or `:zeros `, the first window will be centered
177
- at the start of the signal. For `:reflect`, each incomplete window will be reflected as if it was
178
- periodic (see examples for `as_windowed/2`) . For `:zeros`, each incomplete window will be zero-padded.
175
+ * `:padding` - Padding mode, can be `:reflect` or a valid padding as per `Nx.pad/3` over the
176
+ input tensor's shape. Defaults to `:valid`. If `:reflect` or `:same `, the first window will be centered
177
+ at the start of the signal. The padding is applied for the whole input, rather than individual
178
+ windows . For `:zeros`, effectively each incomplete window will be zero-padded.
179
179
180
180
## Examples
181
181
@@ -219,27 +219,29 @@ defmodule NxSignal do
219
219
iex> t = Nx.iota({7});
220
220
iex> NxSignal.as_windowed(t, window_length: 6, padding: :reflect, stride: 1)
221
221
#Nx.Tensor<
222
- s64[7 ][6]
222
+ s64[8 ][6]
223
223
[
224
- [1 , 2, 1, 0, 1, 2],
224
+ [3 , 2, 1, 0, 1, 2],
225
225
[2, 1, 0, 1, 2, 3],
226
226
[1, 0, 1, 2, 3, 4],
227
227
[0, 1, 2, 3, 4, 5],
228
228
[1, 2, 3, 4, 5, 6],
229
229
[2, 3, 4, 5, 6, 5],
230
- [3, 4, 5, 6, 5, 4]
230
+ [3, 4, 5, 6, 5, 4],
231
+ [4, 5, 6, 5, 4, 3]
231
232
]
232
233
>
233
234
234
235
iex> NxSignal.as_windowed(Nx.iota({10}), window_length: 6, padding: :reflect, stride: 2)
235
236
#Nx.Tensor<
236
- s64[5 ][6]
237
+ s64[6 ][6]
237
238
[
238
- [1 , 2, 1, 0, 1, 2],
239
+ [3 , 2, 1, 0, 1, 2],
239
240
[1, 0, 1, 2, 3, 4],
240
241
[1, 2, 3, 4, 5, 6],
241
242
[3, 4, 5, 6, 7, 8],
242
- [5, 6, 7, 8, 9, 8]
243
+ [5, 6, 7, 8, 9, 8],
244
+ [7, 8, 9, 8, 7, 6]
243
245
]
244
246
>
245
247
"""
@@ -257,7 +259,7 @@ defmodule NxSignal do
257
259
258
260
as_windowed_parse_non_reflect_opts (
259
261
shape ,
260
- Keyword . put ( opts , :padding , [ { div ( window_length , 2 ) , div ( window_length , 2 ) - 1 } ] )
262
+ Keyword . put ( opts , :padding , [ { div ( window_length , 2 ) , div ( window_length , 2 ) } ] )
261
263
)
262
264
end
263
265
@@ -333,114 +335,34 @@ defmodule NxSignal do
333
335
{ window_length , stride , padding , output_shape } =
334
336
as_windowed_parse_non_reflect_opts ( Nx . shape ( tensor ) , opts )
335
337
336
- output = Nx . broadcast ( Nx . tensor ( 0 , type: tensor . type ) , output_shape )
337
- { num_windows , _ } = Nx . shape ( output )
338
-
339
- index_template =
340
- Nx . concatenate ( [ Nx . broadcast ( 0 , { window_length , 1 } ) , Nx . iota ( { window_length , 1 } ) ] , axis: 1 )
341
-
342
- { output , _ , _ , _ , _ } =
343
- while { output , i = 0 , current_window = 0 , t = Nx . pad ( tensor , 0 , padding ) , index_template } ,
344
- current_window < num_windows do
345
- indices = index_template + Nx . stack ( [ current_window , 0 ] )
346
- updates = t |> Nx . slice ( [ i ] , [ window_length ] ) |> Nx . flatten ( )
347
-
348
- updated = Nx . indexed_add ( output , indices , updates )
338
+ tensor = Nx . pad ( tensor , 0 , padding )
349
339
350
- { updated , i + stride , current_window + 1 , t , index_template }
351
- end
352
-
353
- output
340
+ as_windowed_apply ( tensor , stride , output_shape , window_length )
354
341
end
355
342
356
343
defnp as_windowed_reflect_padding ( tensor , opts \\ [ ] ) do
357
344
# current implementation only supports windowing 1D tensors
358
345
{ window_length , stride , _padding , output_shape } =
359
346
as_windowed_parse_reflect_opts ( Nx . shape ( tensor ) , opts )
360
347
361
- output = Nx . broadcast ( Nx . tensor ( 0 , type: tensor . type ) , output_shape )
362
- { num_windows , _ } = Nx . shape ( output )
363
-
364
- index_template =
365
- Nx . concatenate ( [ Nx . broadcast ( 0 , { window_length , 1 } ) , Nx . iota ( { window_length , 1 } ) ] , axis: 1 )
366
-
367
- leading_window_indices = generate_leading_window_indices ( window_length , stride )
368
-
369
- trailing_window_indices =
370
- generate_trailing_window_indices ( Nx . size ( tensor ) , window_length , stride )
371
-
372
- half_window = div ( window_length - 1 , 2 ) + 1
373
-
374
- { output , _ , _ , _ , _ } =
375
- while { output , i = 0 , current_window = 0 , t = tensor , index_template } ,
376
- current_window < num_windows do
377
- # Here windows are centered at the current index
378
-
379
- cond do
380
- i < half_window ->
381
- # We're indexing before we have a full window on the left
382
-
383
- window = Nx . take ( t , leading_window_indices [ i ] )
384
-
385
- indices = index_template + Nx . stack ( [ current_window , 0 ] )
386
- updated = Nx . indexed_add ( output , indices , window )
387
-
388
- { updated , i + stride , current_window + 1 , t , index_template }
389
-
390
- i > Nx . size ( t ) - half_window ->
391
- # We're indexing after the last full window on the right
392
- window = Nx . take ( t , trailing_window_indices [ i - ( Nx . size ( t ) - half_window + 1 ) ] )
393
-
394
- indices = index_template + Nx . stack ( [ current_window , 0 ] )
395
- updated = Nx . indexed_add ( output , indices , window )
396
-
397
- { updated , i + stride , current_window + 1 , t , index_template }
398
-
399
- true ->
400
- # Case where we can index a full window
401
- indices = index_template + Nx . stack ( [ current_window , 0 ] )
402
- updates = t |> Nx . slice ( [ i - half_window ] , [ window_length ] ) |> Nx . flatten ( )
403
-
404
- updated = Nx . indexed_add ( output , indices , updates )
405
-
406
- { updated , i + stride , current_window + 1 , t , index_template }
407
- end
408
- end
409
-
410
- # Now we need to handle the tail-end of the windows,
411
- # since they are currently all the same value. We want to apply the tapering-off
412
- # like we did with the initial windows.
413
-
414
- output
415
- end
416
-
417
- deftransformp generate_leading_window_indices ( window_length , stride ) do
418
348
half_window = div ( window_length , 2 )
349
+ tensor = Nx . reflect ( tensor , padding_config: [ { half_window , half_window } ] )
419
350
420
- for offset <- 0 .. half_window // stride do
421
- partial_length = offset + half_window
422
- padding_length = window_length - partial_length
423
-
424
- { partial_length }
425
- |> Nx . iota ( )
426
- |> Nx . reflect ( padding_config: [ { padding_length , 0 } ] )
427
- end
428
- |> Nx . stack ( )
351
+ as_windowed_apply ( tensor , stride , output_shape , window_length )
429
352
end
430
353
431
- deftransformp generate_trailing_window_indices ( tensor_size , window_length , stride ) do
432
- min_index = tensor_size - window_length + 1
354
+ defnp as_windowed_apply ( tensor , stride , output_shape , window_length ) do
355
+ output = Nx . broadcast ( Nx . tensor ( 0 , type: tensor . type ) , output_shape )
356
+ { num_windows , _ } = Nx . shape ( output )
433
357
434
- for { offset , add } <- Enum . with_index ( min_index .. ( tensor_size - 1 ) // stride ) do
435
- partial_length = tensor_size - offset
436
- padding_length = window_length - partial_length
358
+ { output , _ , _ , _ } =
359
+ while { output , i = 0 , current_window = 0 , t = tensor } , current_window < num_windows do
360
+ window = t |> Nx . slice ( [ i ] , [ window_length ] )
361
+ updated = Nx . put_slice ( output , [ current_window , 0 ] , Nx . new_axis ( window , 0 ) )
362
+ { updated , i + stride , current_window + 1 , t }
363
+ end
437
364
438
- { partial_length }
439
- |> Nx . iota ( )
440
- |> Nx . add ( min_index + add - rem ( window_length , 2 ) )
441
- |> Nx . reflect ( padding_config: [ { 0 , padding_length } ] )
442
- end
443
- |> Nx . stack ( )
365
+ output
444
366
end
445
367
446
368
@ doc """
@@ -548,15 +470,16 @@ defmodule NxSignal do
548
470
iex> Nx.axis_size(z, :frequencies)
549
471
16
550
472
iex> Nx.axis_size(z, :frames)
551
- 5
473
+ 6
552
474
iex> NxSignal.stft_to_mel(z, sampling_rate, fft_length: fft_length, mel_bins: 4)
553
475
#Nx.Tensor<
554
- f32[frames: 5 ][mel: 4]
476
+ f32[frames: 6 ][mel: 4]
555
477
[
556
478
[0.2900530695915222, 0.17422175407409668, 0.18422472476959229, 0.09807997941970825],
557
479
[0.6093881130218506, 0.5647397041320801, 0.4353824257850647, 0.08635270595550537],
558
480
[0.7584103345870972, 0.7085014581680298, 0.5636920928955078, 0.179118812084198],
559
481
[0.8461772203445435, 0.7952491044998169, 0.6470762491226196, 0.2520409822463989],
482
+ [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721],
560
483
[0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721]
561
484
]
562
485
>
0 commit comments