Skip to content

Commit cb3c11a

Browse files
authored
Merge pull request #3 from Ipsedo/improvements-2.1
Improvements 2.1 : - simplified arch - weight norm
2 parents 73056fb + d74db81 commit cb3c11a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+1142
-499
lines changed

.gitattributes

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*.pt filter=lfs diff=lfs merge=lfs -text
2+
*.wav filter=lfs diff=lfs merge=lfs -text

.pre-commit-config.yaml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,31 @@ repos:
1919
rev: 22.10.0
2020
hooks:
2121
- id: black
22-
args: [ --config=pyproject.toml ]
2322
- repo: https://github.com/hadialqattan/pycln
2423
rev: v2.1.2
2524
hooks:
2625
- id: pycln
27-
args: [ --config=pyproject.toml ]
26+
args: [ --config, pyproject.toml ]
2827
- repo: https://github.com/pycqa/isort
2928
rev: 5.12.0
3029
hooks:
3130
- id: isort
3231
files: "\\.(py)$"
33-
args: [ --settings-path=pyproject.toml ]
3432
- repo: local
3533
hooks:
3634
- id: mypy
3735
name: mypy
3836
language: system
3937
entry: mypy
40-
args: [ music_diffusion, tests, --config-file=pyproject.toml ]
38+
args: [ music_diffusion, tests ]
4139
types: [ python ]
4240
pass_filenames: false
4341
require_serial: true
4442
- id: pylint
4543
name: pylint
4644
language: system
4745
entry: pylint
48-
args: [ music_diffusion, tests, --rcfile, pyproject.toml ]
46+
args: [ music_diffusion, tests ]
4947
types: [ python ]
5048
pass_filenames: false
5149
require_serial: true

music_diffusion/__main__.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,19 @@ def main() -> None:
6161
model_parser = sub_command.add_parser("model")
6262

6363
model_parser.add_argument("--steps", type=int, default=4096)
64-
model_parser.add_argument("--beta-1", type=float, default=2.5e-5)
65-
model_parser.add_argument("--beta-t", type=float, default=5e-3)
66-
model_parser.add_argument("--channels", type=int, default=2)
6764
model_parser.add_argument(
6865
"--unet-channels",
6966
type=_channels,
7067
default=[
68+
(2, 16),
7169
(16, 32),
72-
(32, 48),
73-
(48, 64),
74-
(64, 80),
70+
(32, 64),
71+
(64, 128),
72+
(128, 256),
73+
(256, 512),
7574
],
7675
)
77-
model_parser.add_argument("--time-size", type=int, default=48)
78-
model_parser.add_argument("--norm-groups", type=int, default=16)
76+
model_parser.add_argument("--time-size", type=int, default=16)
7977
model_parser.add_argument("--cuda", action="store_true")
8078

8179
# Sub command run {train, generate}
@@ -97,6 +95,10 @@ def main() -> None:
9795
train_parser.add_argument("--save-every", type=int, default=4096)
9896
train_parser.add_argument("-o", "--output-dir", type=str, required=True)
9997
train_parser.add_argument("--nb-samples", type=int, default=5)
98+
train_parser.add_argument("--denoiser-state-dict", type=str)
99+
train_parser.add_argument("--ema-state-dict", type=str)
100+
train_parser.add_argument("--noiser-state-dict", type=str)
101+
train_parser.add_argument("--optim-state-dict", type=str)
100102

101103
# Generate parser
102104
generate_parser = model_sub_command.add_parser("generate")
@@ -106,6 +108,8 @@ def main() -> None:
106108
generate_parser.add_argument("--fast-sample", type=int, required=False)
107109
generate_parser.add_argument("--frames", type=int, required=True)
108110
generate_parser.add_argument("--musics", type=int, required=True)
111+
generate_parser.add_argument("--ema", action="store_true")
112+
generate_parser.add_argument("--magn-scale", type=float, default=1.0)
109113

110114
#######
111115
# Main
@@ -116,12 +120,8 @@ def main() -> None:
116120
if args.mode == "model":
117121
model_options = ModelOptions(
118122
steps=args.steps,
119-
beta_1=args.beta_1,
120-
beta_t=args.beta_t,
121-
input_channels=args.channels,
122123
unet_channels=args.unet_channels,
123124
time_size=args.time_size,
124-
norm_groups=args.norm_groups,
125125
cuda=args.cuda,
126126
)
127127

@@ -137,9 +137,10 @@ def main() -> None:
137137
save_every=args.save_every,
138138
output_directory=args.output_dir,
139139
nb_samples=args.nb_samples,
140-
noiser_state_dict=None,
141-
denoiser_state_dict=None,
142-
optim_state_dict=None,
140+
noiser_state_dict=args.noiser_state_dict,
141+
denoiser_state_dict=args.denoiser_state_dict,
142+
ema_state_dict=args.ema_state_dict,
143+
optim_state_dict=args.optim_state_dict,
143144
)
144145

145146
train(model_options, train_options)
@@ -148,9 +149,11 @@ def main() -> None:
148149
generate_options = GenerateOptions(
149150
fast_sample=args.fast_sample,
150151
denoiser_dict_state=args.denoiser_dict_state,
152+
ema_denoiser=args.ema,
151153
output_dir=args.output_dir,
152154
frames=args.frames,
153155
musics=args.musics,
156+
magn_scale=args.magn_scale,
154157
)
155158

156159
generate(model_options, generate_options)

music_diffusion/data/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
stft_to_magnitude_phase,
77
wav_to_stft,
88
)
9-
from .constants import N_FFT, N_VEC, OUTPUT_SIZES, SAMPLE_RATE, STFT_STRIDE
9+
from .constants import (
10+
BIN_SIZE,
11+
N_FFT,
12+
N_VEC,
13+
OUTPUT_SIZES,
14+
SAMPLE_RATE,
15+
STFT_STRIDE,
16+
)
1017
from .datasets import AudioDataset
1118
from .primitive import simpson, trapezoid
1219
from .transform import (

music_diffusion/data/audio.py

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,95 @@ def bark_scale(
4949
return res
5050

5151

52+
# copied code from
53+
# https://github.com/magenta/magenta/blob/main/magenta/models/gansynth/lib/spectral_ops.py
54+
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
55+
_MEL_HIGH_FREQUENCY_Q = 1127.0
56+
57+
58+
def mel_to_hertz(mel_values: th.Tensor) -> th.Tensor:
59+
return _MEL_BREAK_FREQUENCY_HERTZ * (
60+
th.exp(mel_values / _MEL_HIGH_FREQUENCY_Q) - 1.0
61+
)
62+
63+
64+
def hertz_to_mel(frequencies_hertz: th.Tensor) -> th.Tensor:
65+
return _MEL_HIGH_FREQUENCY_Q * th.log(
66+
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)
67+
)
68+
69+
70+
def linear_to_mel_weight_matrix(
71+
num_mel_bins: int = constants.N_FFT // 2,
72+
num_spectrogram_bins: int = constants.N_FFT // 2,
73+
sample_rate: int = constants.SAMPLE_RATE,
74+
lower_edge_hertz: float = 125.0,
75+
upper_edge_hertz: float = 3800.0,
76+
) -> th.Tensor:
77+
78+
# HTK excludes the spectrogram DC bin.
79+
bands_to_zero = 1
80+
nyquist_hertz = sample_rate / 2.0
81+
linear_frequencies = th.linspace(0.0, nyquist_hertz, num_spectrogram_bins)[
82+
bands_to_zero:, None
83+
]
84+
# spectrogram_bins_mel = hertz_to_mel(linear_frequencies)
85+
86+
# Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
87+
# center of each band is the lower and upper edge of the adjacent bands.
88+
# Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
89+
# num_mel_bins + 2 pieces.
90+
band_edges_mel = th.linspace(
91+
hertz_to_mel(th.tensor(lower_edge_hertz)).item(),
92+
hertz_to_mel(th.tensor(upper_edge_hertz)).item(),
93+
num_mel_bins + 2,
94+
)
95+
96+
lower_edge_mel = band_edges_mel[0:-2]
97+
center_mel = band_edges_mel[1:-1]
98+
upper_edge_mel = band_edges_mel[2:]
99+
100+
freq_res = nyquist_hertz / float(num_spectrogram_bins)
101+
freq_th = 1.5 * freq_res
102+
for i in range(0, num_mel_bins):
103+
center_hz = mel_to_hertz(center_mel[i])
104+
lower_hz = mel_to_hertz(lower_edge_mel[i])
105+
upper_hz = mel_to_hertz(upper_edge_mel[i])
106+
if upper_hz - lower_hz < freq_th:
107+
rhs = 0.5 * freq_th / (center_hz + _MEL_BREAK_FREQUENCY_HERTZ)
108+
dm = _MEL_HIGH_FREQUENCY_Q * th.log(rhs + th.sqrt(1.0 + rhs**2))
109+
lower_edge_mel[i] = center_mel[i] - dm
110+
upper_edge_mel[i] = center_mel[i] + dm
111+
112+
lower_edge_hz = mel_to_hertz(lower_edge_mel)[None, :]
113+
center_hz = mel_to_hertz(center_mel)[None, :]
114+
upper_edge_hz = mel_to_hertz(upper_edge_mel)[None, :]
115+
116+
# Calculate lower and upper slopes for every spectrogram bin.
117+
# Line segments are linear in the mel domain, not Hertz.
118+
lower_slopes = (linear_frequencies - lower_edge_hz) / (
119+
center_hz - lower_edge_hz
120+
)
121+
upper_slopes = (upper_edge_hz - linear_frequencies) / (
122+
upper_edge_hz - center_hz
123+
)
124+
125+
# Intersect the line segments with each other and zero.
126+
mel_weights_matrix = th.maximum(
127+
th.tensor(0.0), th.minimum(lower_slopes, upper_slopes)
128+
)
129+
130+
# Re-add the zeroed lower bins we sliced out above.
131+
# [freq, mel]
132+
mel_weights_matrix = th_f.pad(
133+
mel_weights_matrix, [bands_to_zero, 0, 0, 0], "constant"
134+
)
135+
return mel_weights_matrix
136+
137+
138+
# end of copied code
139+
140+
52141
def wav_to_stft(
53142
wav_p: str,
54143
n_per_seg: int = constants.N_FFT,
@@ -130,6 +219,8 @@ def magnitude_phase_to_wav(
130219
sample_rate: int,
131220
n_fft: int = constants.N_FFT,
132221
stft_stride: int = constants.STFT_STRIDE,
222+
threshold: float = 1.0 / 2**8,
223+
magn_scale: float = 1.0,
133224
) -> None:
134225
assert (
135226
len(magnitude_phase.size()) == 4
@@ -151,7 +242,9 @@ def magnitude_phase_to_wav(
151242
phase = magnitude_phase_flattened[1, :, :]
152243

153244
magnitude = (magnitude + 1.0) / 2.0
245+
magnitude[magnitude < threshold] = 0.0
154246
magnitude = bark_scale(magnitude, "unscale")
247+
magnitude = magnitude * magn_scale
155248

156249
phase = (phase + 1.0) / 2.0 * 2.0 * th.pi - th.pi
157250
phase = simpson(th.zeros(phase.size()[0], 1), phase, 1, 1.0)
@@ -191,34 +284,30 @@ def create_dataset(
191284
elif not isdir(dataset_output_dir):
192285
raise NotADirectoryError(dataset_output_dir)
193286

194-
n_per_seg = constants.N_FFT
195-
stride = constants.STFT_STRIDE
196-
197-
nb_vec = constants.N_VEC
198-
199287
idx = 0
200288

201289
for wav_p in tqdm(w_p):
202-
complex_values = wav_to_stft(wav_p, n_per_seg=n_per_seg, stride=stride)
290+
complex_values = wav_to_stft(
291+
wav_p, n_per_seg=constants.N_FFT, stride=constants.STFT_STRIDE
292+
)
203293

204-
if complex_values.size()[1] < nb_vec:
294+
if complex_values.size()[1] < constants.N_VEC:
205295
continue
206296

207297
magnitude, phase = stft_to_magnitude_phase(
208-
complex_values, nb_vec=nb_vec
298+
complex_values, nb_vec=constants.N_VEC
209299
)
210300

211301
nb_sample = magnitude.size()[0]
212302

213303
for s_idx in range(nb_sample):
214-
s_magnitude = magnitude[s_idx, :, :]
215-
s_phase = phase[s_idx, :, :]
216-
217304
magnitude_phase_path = join(
218305
dataset_output_dir, f"magn_phase_{idx}.pt"
219306
)
220307

221-
magnitude_phase = th.stack([s_magnitude, s_phase], dim=0)
308+
magnitude_phase = th.stack(
309+
[magnitude[s_idx, :, :], phase[s_idx, :, :]], dim=0
310+
)
222311
magnitude_phase = magnitude_phase.to(th.float)
223312

224313
th.save(magnitude_phase, magnitude_phase_path)

music_diffusion/data/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
SAMPLE_RATE: Final[int] = 16000
99

1010
OUTPUT_SIZES: Final[Tuple[int, int]] = (N_FFT // 2, N_VEC)
11+
12+
BIN_SIZE: Final[float] = 1.0 / 2.0**16.0

music_diffusion/data/transform.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import torch as th
55

6+
# pylint: disable=too-few-public-methods
7+
68

79
class ImgTransform(metaclass=ABCMeta):
810
@abstractmethod
@@ -55,3 +57,6 @@ def __init__(self, dtype: th.dtype) -> None:
5557

5658
def __call__(self, img_data: th.Tensor) -> th.Tensor:
5759
return img_data.to(self.__dtype)
60+
61+
62+
# pylint: enable=too-few-public-methods

music_diffusion/generate.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
STFT_STRIDE,
1313
magnitude_phase_to_wav,
1414
)
15-
from .networks import Denoiser
1615
from .options import GenerateOptions, ModelOptions
1716

1817

@@ -27,24 +26,28 @@ def generate(
2726

2827
print("Load model...")
2928

30-
# pylint: disable=duplicate-code
31-
denoiser = Denoiser(
32-
model_options.input_channels,
33-
model_options.steps,
34-
model_options.time_size,
35-
model_options.beta_1,
36-
model_options.beta_t,
37-
model_options.unet_channels,
38-
model_options.norm_groups,
39-
)
40-
# pylint: enable=duplicate-code
29+
denoiser = model_options.new_denoiser()
4130

4231
device = "cuda" if model_options.cuda else "cpu"
4332

44-
denoiser.load_state_dict(
45-
th.load(generate_options.denoiser_dict_state, map_location=device)
33+
loaded_state_dict = th.load(
34+
generate_options.denoiser_dict_state, map_location=device
4635
)
4736

37+
ema_prefix = "ema_model."
38+
39+
state_dict = (
40+
{
41+
k[len(ema_prefix) :]: p
42+
for k, p in loaded_state_dict.items()
43+
if k.startswith(ema_prefix)
44+
}
45+
if generate_options.ema_denoiser
46+
else loaded_state_dict
47+
)
48+
49+
denoiser.load_state_dict(state_dict)
50+
4851
denoiser.eval()
4952

5053
if model_options.cuda:
@@ -58,7 +61,7 @@ def generate(
5861

5962
x_t = th.randn(
6063
generate_options.musics,
61-
model_options.input_channels,
64+
model_options.unet_channels[0][0],
6265
height,
6366
width * generate_options.frames,
6467
device=device,
@@ -85,4 +88,5 @@ def generate(
8588
SAMPLE_RATE,
8689
N_FFT,
8790
STFT_STRIDE,
91+
magn_scale=generate_options.magn_scale,
8892
)

0 commit comments

Comments
 (0)