Skip to content

Commit 4159154

Browse files
authored
Implement Short-time Fourier transform and its inverse (#587)
* Initial stft implementation * Finish STFT * Cleanup * Bump AMDGPU compat * Install GPU backends only when testing them * Add spectrogram * Move audio documentation to its own page - Convert examples to doctests or evaluate during build time -More tests * Fixes * Use Makie for spectrogram plots * Add mel-scale filterbanks * Run doctests when building documentation instead of a separate CI stage * Minor fix * Move FFTW-dependent functions to extension
1 parent 62f6074 commit 4159154

File tree

19 files changed

+808
-55
lines changed

19 files changed

+808
-55
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,6 @@ jobs:
9393
using Pkg
9494
Pkg.develop(PackageSpec(path=pwd()))
9595
Pkg.instantiate()'
96-
- run: |
97-
julia --color=yes --project=docs/ -e '
98-
using NNlib
99-
# using Pkg; Pkg.activate("docs")
100-
using Documenter
101-
using Documenter: doctest
102-
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive=true)
103-
doctest(NNlib)'
10496
- run: julia --project=docs docs/make.jl
10597
env:
10698
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

Project.toml

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,51 +17,31 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1717
[weakdeps]
1818
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1919
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
20-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
2120
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
21+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
22+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
2223

2324
[extensions]
2425
NNlibAMDGPUExt = "AMDGPU"
2526
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
2627
NNlibCUDAExt = "CUDA"
2728
NNlibEnzymeCoreExt = "EnzymeCore"
29+
NNlibFFTWExt = "FFTW"
2830

2931
[compat]
30-
AMDGPU = "0.8, 0.9"
32+
AMDGPU = "0.9.4"
3133
Adapt = "3.2, 4"
3234
Atomix = "0.1"
3335
CUDA = "4, 5"
36+
cuDNN = "1"
3437
ChainRulesCore = "1.13"
3538
EnzymeCore = "0.5, 0.6, 0.7"
39+
FFTW = "1.8.0"
3640
GPUArraysCore = "0.1"
3741
KernelAbstractions = "0.9.2"
3842
LinearAlgebra = "<0.0.1, 1"
3943
Pkg = "<0.0.1, 1"
4044
Random = "<0.0.1, 1"
4145
Requires = "1.0"
4246
Statistics = "1"
43-
cuDNN = "1"
4447
julia = "1.9"
45-
46-
[extras]
47-
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
48-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
49-
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
50-
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
51-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
52-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
53-
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
54-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
55-
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
56-
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
57-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
58-
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
59-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
60-
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
61-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
62-
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
63-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
64-
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
65-
66-
[targets]
67-
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils", "Interpolations", "ImageTransformations"]

docs/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
[deps]
2+
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
23
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4+
FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99"
5+
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
6+
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
37
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
8+
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
9+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"

docs/make.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
using Documenter, NNlib
22

3-
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive = true)
3+
DocMeta.setdocmeta!(NNlib, :DocTestSetup,
4+
:(using FFTW, NNlib, UnicodePlots); recursive = true)
45

56
makedocs(modules = [NNlib],
6-
sitename = "NNlib.jl",
7-
doctest = false,
8-
pages = ["Home" => "index.md",
9-
"Reference" => "reference.md"],
10-
format = Documenter.HTML(
11-
canonical = "https://fluxml.ai/NNlib.jl/stable/",
12-
# analytics = "UA-36890222-9",
13-
assets = ["assets/flux.css"],
14-
prettyurls = get(ENV, "CI", nothing) == "true"),
15-
warnonly=[:missing_docs,]
16-
)
7+
sitename = "NNlib.jl",
8+
doctest = true,
9+
pages = ["Home" => "index.md",
10+
"Reference" => "reference.md",
11+
"Audio" => "audio.md"],
12+
format = Documenter.HTML(
13+
canonical = "https://fluxml.ai/NNlib.jl/stable/",
14+
# analytics = "UA-36890222-9",
15+
assets = ["assets/flux.css"],
16+
prettyurls = get(ENV, "CI", nothing) == "true"),
17+
warnonly=[:missing_docs,]
18+
)
1719

1820
deploydocs(repo = "github.com/FluxML/NNlib.jl.git",
1921
target = "build",

docs/src/assets/jfk.flac

195 KB
Binary file not shown.

docs/src/audio.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Reference
2+
3+
!!! note
4+
Spectral functions require importing `FFTW` package to enable them.
5+
6+
## Window functions
7+
8+
```@docs
9+
hann_window
10+
hamming_window
11+
```
12+
13+
## Spectral
14+
15+
```@docs
16+
stft
17+
istft
18+
NNlib.power_to_db
19+
NNlib.db_to_power
20+
```
21+
22+
## Spectrogram
23+
24+
```@docs
25+
melscale_filterbanks
26+
spectrogram
27+
```
28+
29+
Example:
30+
31+
```@example 1
32+
using FFTW # <- required for STFT support.
33+
using NNlib
34+
using FileIO
35+
using Makie, CairoMakie
36+
CairoMakie.activate!()
37+
38+
waveform, sampling_rate = load("./assets/jfk.flac")
39+
fig = lines(reshape(waveform, :))
40+
save("waveform.png", fig)
41+
42+
# Spectrogram.
43+
44+
n_fft = 1024
45+
spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft))
46+
fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1]))
47+
save("spectrogram.png", fig)
48+
49+
# Mel-scale spectrogram.
50+
51+
n_freqs = n_fft ÷ 2 + 1
52+
fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate))
53+
mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels)
54+
fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1])
55+
save("mel-spectrogram.png", fig)
56+
nothing # hide
57+
```
58+
59+
|Waveform|Spectrogram|Mel Spectrogram|
60+
|:---:|:---:|:---:|
61+
|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)|

ext/NNlibFFTWExt/NNlibFFTWExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module NNlibFFTWExt
2+
3+
using FFTW
4+
using NNlib
5+
using KernelAbstractions
6+
7+
include("stft.jl")
8+
9+
end

ext/NNlibFFTWExt/stft.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
function NNlib.stft(x;
2+
n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
3+
center::Bool = true, normalized::Bool = false,
4+
)
5+
kab = get_backend(x)
6+
use_window = !isnothing(window)
7+
8+
use_window && kab != get_backend(window) && throw(ArgumentError(
9+
"`window` must be on the same device as stft input `x` ($kab), \
10+
instead: `$(get_backend(window))`."))
11+
use_window && !(0 < length(window) n_fft) && throw(ArgumentError(
12+
"Expected `0 < length(window) ≤ n_fft=$n_fft`, \
13+
but got `length(window)=$(length(window))`."))
14+
hop_length < 0 && throw(ArgumentError(
15+
"Expected `hop_length > 0`, but got `hop_length=$hop_length`."))
16+
17+
# Pad window on both sides with `0` to `n_fft` length if needed.
18+
if use_window && length(window) < n_fft
19+
left = ((n_fft - length(window)) ÷ 2) + 1
20+
tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
21+
tmp[left:left + length(window) - 1] .= window
22+
window = tmp
23+
end
24+
25+
if center
26+
pad_amount = n_fft ÷ 2
27+
x = pad_reflect(x, pad_amount; dims=1)
28+
end
29+
30+
n = size(x, 1)
31+
(0 < n_fft n) || throw(ArgumentError(
32+
"Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`."))
33+
34+
n_frames = 1 + (n - n_fft) ÷ hop_length
35+
36+
# time2col.
37+
# Reshape `x` to (n_fft, n_frames, B) if needed.
38+
# Each row in `n_frames` is shifted by `hop_length`.
39+
if n_frames > 1
40+
# TODO can be more efficient if we support something like torch.as_strided
41+
ids = [
42+
row + hop_length * col
43+
for row in 1:n_fft, col in 0:(n_frames - 1)]
44+
x = x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
45+
end
46+
47+
region = 1
48+
use_window && (x = x .* window;)
49+
y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)
50+
51+
normalized && (y = y .* eltype(y)(n_fft^-0.5);)
52+
return y
53+
end
54+
55+
function NNlib.istft(y;
56+
n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
57+
center::Bool = true, normalized::Bool = false,
58+
return_complex::Bool = false,
59+
original_length::Union{Nothing, Int} = nothing,
60+
)
61+
kab = get_backend(y)
62+
use_window = !isnothing(window)
63+
64+
use_window && kab != get_backend(window) && throw(ArgumentError(
65+
"`window` must be on the same device as istft input `y` ($kab), \
66+
instead: `$(get_backend(window))`."))
67+
use_window && !(0 < length(window) n_fft) && throw(ArgumentError(
68+
"Expected `0 < length(window) ≤ n_fft=$n_fft`, \
69+
but got `length(window)=$(length(window))`."))
70+
hop_length < 0 && throw(ArgumentError(
71+
"Expected `hop_length > 0`, but got `hop_length=$hop_length`."))
72+
73+
# TODO check `y` eltype is complex
74+
75+
n_frames = size(y, 2)
76+
77+
# Pad window on both sides with `0` to `n_fft` length if needed.
78+
if use_window && length(window) < n_fft
79+
left = ((n_fft - length(window)) ÷ 2) + 1
80+
tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
81+
tmp[left:left + length(window) - 1] .= window
82+
window = tmp
83+
end
84+
85+
# Denormalize.
86+
normalized && (y = y .* eltype(y)(n_fft^0.5);)
87+
88+
region = 1
89+
x = return_complex ? ifft(y, region) : irfft(y, n_fft, region)
90+
91+
# De-apply window.
92+
use_window && (x = x ./ window;)
93+
94+
# col2time.
95+
expected_output_len = n_fft + hop_length * (n_frames - 1)
96+
97+
ids = Vector{Int}(undef, expected_output_len)
98+
in_idx, out_idx = 0, 0
99+
prev_e, v = 0, 0
100+
101+
for col in 0:(n_frames - 1)
102+
for row in 1:n_fft
103+
in_idx += 1
104+
v = row + hop_length * col
105+
v > prev_e || continue
106+
107+
out_idx += 1
108+
ids[out_idx] = in_idx
109+
end
110+
prev_e = v
111+
end
112+
113+
# In case of batched input, reshaped it (n_fft, n_frames, batch) -> (:, batch).
114+
nd = ntuple(_ -> Colon(), ndims(x) - 2)
115+
ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
116+
x = x[ids, nd...]
117+
118+
# Trim padding.
119+
left = center ? (n_fft ÷ 2 + 1) : 1
120+
right = if isnothing(original_length)
121+
center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len
122+
else
123+
left + original_length - 1
124+
end
125+
x = x[left:right, nd...]
126+
return x
127+
end

src/NNlib.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,9 @@ include("deprecations.jl")
124124
include("rotation.jl")
125125
export imrotate, ∇imrotate
126126

127+
include("audio/stft.jl")
128+
include("audio/spectrogram.jl")
129+
include("audio/mel.jl")
130+
export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks
131+
127132
end # module NNlib

0 commit comments

Comments
 (0)