-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
118 lines (99 loc) · 3.37 KB
/
test.py
File metadata and controls
118 lines (99 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import json
import os
from pathlib import Path
import torch
import torchaudio
from tqdm import tqdm
import vocoder.model as module_model
from vocoder.trainer import Trainer
from vocoder.utils import ROOT_PATH
from vocoder.utils.parse_config import ConfigParser
from vocoder.melspec import MelSpectrogram, MelSpectrogramConfig
# Fallback checkpoint used when the caller does not pass --resume on the CLI.
DEFAULT_CHECKPOINT_PATH = ROOT_PATH / "default_test_model" / "checkpoint.pth"
def main(config, test_dir, output_dir, device):
    """Run vocoder inference over a directory of wav files.

    Each ``.wav`` in ``test_dir`` is converted to a mel spectrogram,
    resynthesized by the trained generator, and written (same filename)
    into ``output_dir``.

    Args:
        config: ConfigParser-style object exposing ``["arch"]``, ``["n_gpu"]``,
            ``.resume``, ``.get_logger`` and ``.init_obj``.
        test_dir: path to a directory containing input wav files.
        output_dir: directory for generated wavs; created if missing.
        device: torch device string (e.g. ``"cpu"``, ``"cuda"``). If ``None``,
            falls back to ``"cuda"`` when available, else ``"cpu"``.
    """
    logger = config.get_logger("test")

    # The CLI default for --device is None; torch.device(None) raises,
    # so pick a sensible device automatically in that case.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # build model architecture from the config
    model = config.init_obj(config["arch"], module_model)
    logger.info(model)

    logger.info("Loading checkpoint: {} ...".format(config.resume))
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint["state_dict"]
    if config["n_gpu"] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for inference
    logger.info(f"Device {device}")
    model = model.to(device)
    model.eval()
    # Strip normalization layers from the generator before inference
    # (what exactly is removed is defined in vocoder.model).
    model.generator.remove_normalization()

    os.makedirs(output_dir, exist_ok=True)
    test_dir = Path(test_dir)
    output_dir = Path(output_dir)

    # NOTE(review): assumes inputs are 22.05 kHz — confirm against
    # MelSpectrogramConfig before changing.
    sampling_rate = 22050
    mel_spec_config = MelSpectrogramConfig()
    mel_spec_transform = MelSpectrogram(mel_spec_config).to(device)

    with torch.no_grad():
        for wav_path in tqdm(test_dir.iterdir(), "Processing wavs"):
            wav = torchaudio.load(wav_path)[0].to(device)
            mel_spec = mel_spec_transform(wav)
            wav_pred = model.generator(mel_spec).squeeze(0).cpu()
            torchaudio.save(output_dir / wav_path.name, wav_pred, sample_rate=sampling_rate)
if __name__ == "__main__":
    # CLI entry point: parse arguments, load the model config, run inference.
    args = argparse.ArgumentParser(description="PyTorch Template")
    args.add_argument(
        "-c",
        "--config",
        default=None,
        type=str,
        help="config file path (default: config.json next to the checkpoint)",
    )
    args.add_argument(
        "-r",
        "--resume",
        default=str(DEFAULT_CHECKPOINT_PATH.absolute().resolve()),
        type=str,
        help="path to latest checkpoint (default: None)",
    )
    args.add_argument(
        "-d",
        "--device",
        default=None,
        type=str,
        help="indices of GPUs to enable (default: all)",
    )
    args.add_argument(
        "-t",
        "--inference-dir",
        default="test_audio",
        type=str,
        help="Directory with test audio wav files",
    )
    args.add_argument(
        "-o",
        "--output-dir",
        default="output",
        type=str,
        help="Output directory",
    )
    args.add_argument(
        "-j",
        "--jobs",
        default=1,
        type=int,
        help="Number of workers for test dataloader",
    )
    args = args.parse_args()

    # Obtain the config with model parameters. If --config was not given,
    # assume it is located next to the checkpoint (the previous code crashed
    # with TypeError on Path(None) in that case). The old second read of
    # args.config was redundant — the same file was already loaded above.
    if args.config is not None:
        model_config = Path(args.config)
    else:
        model_config = Path(args.resume).parent / "config.json"
    with model_config.open() as f:
        config = ConfigParser(json.load(f), resume=args.resume)

    main(config, args.inference_dir, args.output_dir, args.device)