@@ -718,7 +718,7 @@ defmodule NxSignal do
718
718
719
719
@ doc """
720
720
Performs the overlap-and-add algorithm over
721
- an M by N tensor, where M is the number of
721
+ an {..., M, N}-shaped tensor, where M is the number of
722
722
windows and N is the window size.
723
723
724
724
The tensor is zero-padded on the right so
@@ -736,60 +736,80 @@ defmodule NxSignal do
736
736
s64[12]
737
737
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
738
738
>
739
+
739
740
iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 3)
740
741
#Nx.Tensor<
741
742
s64[6]
742
743
[0, 5, 15, 18, 17, 11]
743
744
>
745
+
746
+ iex> t = Nx.tensor([[[[0, 1, 2, 3], [4, 5, 6, 7]]], [[[10, 11, 12, 13], [14, 15, 16, 17]]]]) |> Nx.vectorize(x: 2, y: 1)
747
+ iex> NxSignal.overlap_and_add(t, overlap_length: 3)
748
+ #Nx.Tensor<
749
+ vectorized[x: 2][y: 1]
750
+ s64[5]
751
+ [
752
+ [
753
+ [0, 5, 7, 9, 7]
754
+ ],
755
+ [
756
+ [10, 25, 27, 29, 17]
757
+ ]
758
+ ]
759
+ >
744
760
"""
745
761
@ doc type: :windowing
746
762
defn overlap_and_add ( tensor , opts \\ [ ] ) do
747
- opts = keyword! ( opts , [ :overlap_length ] )
748
-
749
- { num_windows , window_length } = Nx . shape ( tensor )
763
+ opts = keyword! ( opts , [ :overlap_length , type: Nx . type ( tensor ) ] )
750
764
overlap_length = opts [ :overlap_length ]
751
765
766
+ % { vectorized_axes: vectorized_axes , shape: input_shape } = tensor
767
+ num_windows = Nx . axis_size ( tensor , - 2 )
768
+ window_length = Nx . axis_size ( tensor , - 1 )
769
+
752
770
if overlap_length >= window_length do
753
771
raise ArgumentError ,
754
772
"overlap_length must be a number less than the window size #{ window_length } , got: #{ inspect ( window_length ) } "
755
773
end
756
774
775
+ tensor =
776
+ Nx . revectorize ( tensor , [ condensed_vectors: :auto , windows: num_windows ] ,
777
+ target_shape: { window_length }
778
+ )
779
+
757
780
stride = window_length - overlap_length
758
781
output_holder_shape = { num_windows * stride + overlap_length }
759
782
760
- { output , _ , _ , _ , _ , _ } =
761
- while {
762
- out =
763
- Nx . broadcast (
764
- Nx . tensor ( 0 , type: tensor . type ) ,
765
- output_holder_shape
766
- ) ,
767
- tensor ,
768
- i = 0 ,
769
- idx_template = Nx . iota ( { window_length , 1 } ) ,
770
- stride ,
771
- num_windows
772
- } ,
773
- i < num_windows do
774
- current_window = tensor [ i ]
775
- idx = idx_template + i * stride
776
-
777
- {
778
- Nx . indexed_add ( out , idx , current_window ) ,
779
- tensor ,
780
- i + 1 ,
781
- idx_template ,
782
- stride ,
783
- num_windows
784
- }
785
- end
783
+ out =
784
+ Nx . broadcast (
785
+ Nx . tensor ( 0 , type: tensor . type ) ,
786
+ output_holder_shape
787
+ )
786
788
787
- case opts [ :type ] do
788
- nil ->
789
- output
789
+ idx_template = Nx . iota ( { window_length , 1 } , vectorized_axes: [ windows: 1 ] )
790
+ i = Nx . iota ( { num_windows } ) |> Nx . vectorize ( :windows )
791
+ idx = idx_template + i * stride
790
792
791
- t ->
792
- Nx . as_type ( output , t )
793
- end
793
+ [ % { vectorized_axes: [ condensed_vectors: n , windows: _ ] } = tensor , idx ] =
794
+ Nx . broadcast_vectors ( [ tensor , idx ] )
795
+
796
+ tensor = Nx . revectorize ( tensor , [ condensed_vectors: n ] , target_shape: { :auto } )
797
+ idx = Nx . revectorize ( idx , [ condensed_vectors: n ] , target_shape: { :auto , 1 } )
798
+
799
+ out_shape = overlap_and_add_output_shape ( out . shape , input_shape )
800
+
801
+ out
802
+ |> Nx . indexed_add ( idx , tensor )
803
+ |> Nx . as_type ( opts [ :type ] )
804
+ |> Nx . revectorize ( vectorized_axes , target_shape: out_shape )
805
+ end
806
+
807
+ deftransformp overlap_and_add_output_shape ( { out_len } , in_shape ) do
808
+ idx = tuple_size ( in_shape ) - 2
809
+
810
+ in_shape
811
+ |> Tuple . delete_at ( idx )
812
+ |> Tuple . delete_at ( idx )
813
+ |> Tuple . append ( out_len )
794
814
end
795
815
end
0 commit comments