Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions audio_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import torch
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util


def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    total_len = n_fft + hop_length * (n_frames - 1)
    envelope = np.zeros(total_len, dtype=dtype)

    # Build the squared window once: normalize, square, then center-pad
    # to the analysis frame length.
    squared_win = get_window(window, win_length, fftbins=True)
    squared_win = librosa_util.normalize(squared_win, norm=norm) ** 2
    squared_win = librosa_util.pad_center(squared_win, n_fft)

    # Overlap-add the squared window at every frame offset, truncating the
    # final frames so they never write past the end of the envelope.
    for frame_idx in range(n_frames):
        start = frame_idx * hop_length
        stop = min(total_len, start + n_fft)
        envelope[start:stop] += squared_win[:max(0, stop - start)]
    return envelope


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    Reconstruct a time-domain signal from spectrogram magnitudes via the
    Griffin-Lim algorithm: iteratively re-estimate phases from the current
    signal while keeping the target magnitudes fixed.

    PARAMS
    ------
    magnitudes: spectrogram magnitudes (torch tensor)
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    n_iters: number of phase-refinement iterations

    RETURNS
    -------
    signal: reconstructed waveform tensor (frame dimension squeezed out)
    """
    # Start from uniformly random phases on the unit circle.
    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    # torch.autograd.Variable is deprecated since PyTorch 0.4; a plain
    # tensor behaves identically here (no gradients are needed).
    angles = torch.from_numpy(angles.astype(np.float32))
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    # Alternate between extracting phases from the current signal and
    # resynthesizing with the fixed target magnitudes.
    for _ in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    Log-compress the dynamic range of a tensor.

    PARAMS
    ------
    C: compression factor
    clip_val: lower bound applied before the log to avoid log(0)
    """
    clipped = torch.clamp(x, min=clip_val)
    return torch.log(clipped * C)


def dynamic_range_decompression(x, C=1):
    """
    Invert log dynamic-range compression.

    PARAMS
    ------
    C: compression factor used to compress
    """
    decompressed = torch.exp(x)
    return decompressed / C
37 changes: 37 additions & 0 deletions config_32k.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 16,
"learning_rate": 0.0002,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,

"upsample_rates": [8,8,4,2],
"upsample_kernel_sizes": [16,16,8,4],
"upsample_initial_channel": 256,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],

"segment_size": 16384,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 2048,
"hop_size": 512,
"win_size": 2048,

"sampling_rate": 32000,

"fmin": 0,
"fmax": 11025,
"fmax_for_loss": null,

"num_workers": 4,

"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}
37 changes: 37 additions & 0 deletions config_v1b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 16,
"learning_rate": 0.00003,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.97,
"seed": 1234,

"upsample_rates": [8,8,2,2],
"upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],

"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,

"sampling_rate": 22050,

"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,

"num_workers": 4,

"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}
38 changes: 38 additions & 0 deletions denoiser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import sys
import torch
from stft import STFT


class Denoiser(torch.nn.Module):
    """WaveGlow denoiser, adapted for HiFi-GAN.

    Estimates the vocoder's constant "bias" output — the audio it produces
    from an all-zero (or random) mel input — and subtracts a scaled version
    of that bias magnitude spectrum from generated audio to remove hiss.
    """

    # Target device shared by the STFT module and the probe mel input.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def __init__(
        self, hifigan, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"
    ):
        """
        PARAMS
        ------
        hifigan: vocoder mapping a (1, 80, T) mel tensor to audio
        filter_length: STFT FFT size
        n_overlap: overlap factor; hop_length = filter_length / n_overlap
        win_length: STFT window length
        mode: "zeros" or "normal" — probe mel input used to elicit the bias
        """
        super(Denoiser, self).__init__()
        self.stft = STFT(
            filter_length=filter_length,
            hop_length=int(filter_length / n_overlap),
            win_length=win_length,
        ).to(Denoiser.device)
        # Probe mel input used to elicit the vocoder's bias signal.
        if mode == "zeros":
            mel_input = torch.zeros((1, 80, 88)).to(Denoiser.device)
        elif mode == "normal":
            mel_input = torch.randn((1, 80, 88)).to(Denoiser.device)
        else:
            # Fixed message wording ("is not supported", was "if not
            # supported") and raise ValueError instead of bare Exception.
            raise ValueError("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = hifigan(mel_input).view(1, -1).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        # Keep only the first frame's magnitudes as the per-bin bias estimate.
        self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        """Subtract `strength` times the bias magnitude spectrum from `audio`
        and resynthesize with the original phases."""
        audio_spec, audio_angles = self.stft.transform(audio.to(Denoiser.device).float())
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        # Magnitudes cannot be negative; clamp after the subtraction.
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised
File renamed without changes.
31 changes: 17 additions & 14 deletions inference_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@ def load_checkpoint(filepath, device):


def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + '*')
pattern = os.path.join(cp_dir, prefix + "*")
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return ''
return ""
return sorted(cp_list)[-1]


def inference(a):
generator = Generator(h).to(device)

state_dict_g = load_checkpoint(a.checkpoint_file, device)
generator.load_state_dict(state_dict_g['generator'])
generator.load_state_dict(state_dict_g["generator"])

filelist = os.listdir(a.input_mels_dir)

Expand All @@ -45,28 +45,32 @@ def inference(a):
generator.remove_weight_norm()
with torch.no_grad():
for i, filname in enumerate(filelist):
if ".npy" not in filname:
continue
x = np.load(os.path.join(a.input_mels_dir, filname))
x = torch.FloatTensor(x).to(device)
y_g_hat = generator(x)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype('int16')
audio = audio.cpu().numpy().astype("int16")

output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav')
output_file = os.path.join(
a.output_dir, os.path.splitext(filname)[0] + ".wav"
)
write(output_file, h.sampling_rate, audio)
print(output_file)


def main():
print('Initializing Inference Process..')
print("Initializing Inference Process..")

parser = argparse.ArgumentParser()
parser.add_argument('--input_mels_dir', default='test_mel_files')
parser.add_argument('--output_dir', default='generated_files_from_mel')
parser.add_argument('--checkpoint_file', required=True)
parser.add_argument("--input_mels_dir", default="test_mel_files")
parser.add_argument("--output_dir", default="generated_files_from_mel")
parser.add_argument("--checkpoint_file", required=True)
a = parser.parse_args()

config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
with open(config_file) as f:
data = f.read()

Expand All @@ -78,13 +82,12 @@ def main():
global device
if torch.cuda.is_available():
torch.cuda.manual_seed(h.seed)
device = torch.device('cuda')
device = torch.device("cuda")
else:
device = torch.device('cpu')
device = torch.device("cpu")

inference(a)


if __name__ == '__main__':
if __name__ == "__main__":
main()

10 changes: 5 additions & 5 deletions meldataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
y = y.squeeze(1)

spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
center=center, pad_mode='reflect', normalized=False, onesided=True)
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)

spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
spec = torch.abs(spec)

spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
spec = spectral_normalize_torch(spec)
Expand All @@ -74,11 +74,11 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,

def get_dataset_filelist(a):
with open(a.input_training_file, 'r', encoding='utf-8') as fi:
training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0])
for x in fi.read().split('\n') if len(x) > 0]

with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0])
for x in fi.read().split('\n') if len(x) > 0]
return training_files, validation_files

Expand Down Expand Up @@ -141,7 +141,7 @@ def __getitem__(self, index):
center=False)
else:
mel = np.load(
os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
os.path.join(self.base_mels_path, os.path.splitext(filename)[0] + '.npy'))
mel = torch.from_numpy(mel)

if len(mel.shape) < 3:
Expand Down
2 changes: 1 addition & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils import init_weights, get_padding
from hifiutils import init_weights, get_padding

LRELU_SLOPE = 0.1

Expand Down
Loading