Skip to content

Commit b222d3b

Browse files
authored
feat: use vectorization for overlap_and_add (#11)
* feat: use vectorization for overlap_and_add * fix: mix.exs version * fix: ci elixir version * refactor: use Nx.revectorize * fix: support type opt * docs
1 parent 1162660 commit b222d3b

File tree

4 files changed

+61
-41
lines changed

4 files changed

+61
-41
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
strategy:
1212
fail-fast: false
1313
matrix:
14-
elixir: ["1.13.0"]
14+
elixir: ["1.14.0"]
1515
otp: ["24.0"]
1616
env:
1717
MIX_ENV: test
@@ -42,7 +42,7 @@ jobs:
4242
strategy:
4343
fail-fast: false
4444
matrix:
45-
elixir: ["1.13.0"]
45+
elixir: ["1.14.0"]
4646
otp: ["24.0"]
4747
env:
4848
MIX_ENV: test

lib/nx_signal.ex

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ defmodule NxSignal do
718718

719719
@doc """
720720
Performs the overlap-and-add algorithm over
721-
an M by N tensor, where M is the number of
721+
an {..., M, N}-shaped tensor, where M is the number of
722722
windows and N is the window size.
723723
724724
The tensor is zero-padded on the right so
@@ -736,60 +736,80 @@ defmodule NxSignal do
736736
s64[12]
737737
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
738738
>
739+
739740
iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 3)
740741
#Nx.Tensor<
741742
s64[6]
742743
[0, 5, 15, 18, 17, 11]
743744
>
745+
746+
iex> t = Nx.tensor([[[[0, 1, 2, 3], [4, 5, 6, 7]]], [[[10, 11, 12, 13], [14, 15, 16, 17]]]]) |> Nx.vectorize(x: 2, y: 1)
747+
iex> NxSignal.overlap_and_add(t, overlap_length: 3)
748+
#Nx.Tensor<
749+
vectorized[x: 2][y: 1]
750+
s64[5]
751+
[
752+
[
753+
[0, 5, 7, 9, 7]
754+
],
755+
[
756+
[10, 25, 27, 29, 17]
757+
]
758+
]
759+
>
744760
"""
745761
@doc type: :windowing
746762
defn overlap_and_add(tensor, opts \\ []) do
747-
opts = keyword!(opts, [:overlap_length])
748-
749-
{num_windows, window_length} = Nx.shape(tensor)
763+
opts = keyword!(opts, [:overlap_length, type: Nx.type(tensor)])
750764
overlap_length = opts[:overlap_length]
751765

766+
%{vectorized_axes: vectorized_axes, shape: input_shape} = tensor
767+
num_windows = Nx.axis_size(tensor, -2)
768+
window_length = Nx.axis_size(tensor, -1)
769+
752770
if overlap_length >= window_length do
753771
raise ArgumentError,
754772
"overlap_length must be a number less than the window size #{window_length}, got: #{inspect(window_length)}"
755773
end
756774

775+
tensor =
776+
Nx.revectorize(tensor, [condensed_vectors: :auto, windows: num_windows],
777+
target_shape: {window_length}
778+
)
779+
757780
stride = window_length - overlap_length
758781
output_holder_shape = {num_windows * stride + overlap_length}
759782

760-
{output, _, _, _, _, _} =
761-
while {
762-
out =
763-
Nx.broadcast(
764-
Nx.tensor(0, type: tensor.type),
765-
output_holder_shape
766-
),
767-
tensor,
768-
i = 0,
769-
idx_template = Nx.iota({window_length, 1}),
770-
stride,
771-
num_windows
772-
},
773-
i < num_windows do
774-
current_window = tensor[i]
775-
idx = idx_template + i * stride
776-
777-
{
778-
Nx.indexed_add(out, idx, current_window),
779-
tensor,
780-
i + 1,
781-
idx_template,
782-
stride,
783-
num_windows
784-
}
785-
end
783+
out =
784+
Nx.broadcast(
785+
Nx.tensor(0, type: tensor.type),
786+
output_holder_shape
787+
)
786788

787-
case opts[:type] do
788-
nil ->
789-
output
789+
idx_template = Nx.iota({window_length, 1}, vectorized_axes: [windows: 1])
790+
i = Nx.iota({num_windows}) |> Nx.vectorize(:windows)
791+
idx = idx_template + i * stride
790792

791-
t ->
792-
Nx.as_type(output, t)
793-
end
793+
[%{vectorized_axes: [condensed_vectors: n, windows: _]} = tensor, idx] =
794+
Nx.broadcast_vectors([tensor, idx])
795+
796+
tensor = Nx.revectorize(tensor, [condensed_vectors: n], target_shape: {:auto})
797+
idx = Nx.revectorize(idx, [condensed_vectors: n], target_shape: {:auto, 1})
798+
799+
out_shape = overlap_and_add_output_shape(out.shape, input_shape)
800+
801+
out
802+
|> Nx.indexed_add(idx, tensor)
803+
|> Nx.as_type(opts[:type])
804+
|> Nx.revectorize(vectorized_axes, target_shape: out_shape)
805+
end
806+
807+
deftransformp overlap_and_add_output_shape({out_len}, in_shape) do
808+
idx = tuple_size(in_shape) - 2
809+
810+
in_shape
811+
|> Tuple.delete_at(idx)
812+
|> Tuple.delete_at(idx)
813+
|> Tuple.append(out_len)
794814
end
795815
end

mix.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ defmodule NxSignal.MixProject do
88
[
99
app: :nx_signal,
1010
version: @version,
11-
elixir: "~> 1.13",
11+
elixir: "~> 1.14",
1212
start_permanent: Mix.env() == :prod,
1313
elixirc_paths: elixirc_paths(Mix.env()),
1414
deps: deps(),
@@ -56,7 +56,7 @@ defmodule NxSignal.MixProject do
5656
# Run "mix help deps" to learn about dependencies.
5757
defp deps do
5858
[
59-
{:nx, "~> 0.5"},
59+
{:nx, github: "elixir-nx/nx", sparse: "nx"},
6060
{:ex_doc, "~> 0.29", only: :docs}
6161
]
6262
end

mix.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
88
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
99
"nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"},
10-
"nx": {:hex, :nx, "0.5.0", "c5e62e82606ff372d986e72cce505c98421bb4305ce9cc8e439fe6cc1966c6ad", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "b29c246318181c3ebfcf0f230a0d33783ac4c92dfa34ca3aa5b9b38ae58c187e"},
10+
"nx": {:git, "https://github.com/elixir-nx/nx.git", "16ecbc6dbbde5fc5e122f8013601bcc4af2ef4c1", [sparse: "nx"]},
1111
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
1212
"xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"},
1313
}

0 commit comments

Comments
 (0)