-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmir_eval_thresh_notes.py
More file actions
130 lines (102 loc) · 4.72 KB
/
mir_eval_thresh_notes.py
File metadata and controls
130 lines (102 loc) · 4.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import csv
import numpy as np
import torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import mir_eval
from utils.util import load_config, notes_to_hz_mir_transcription, notes_to_hz_mir_multipitch, resize_target, get_ffm
from models.cnn import CNN
from core.audio_processor import AudioProcessor
import sys

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Command line: <exp_dir> <data_csv>. Fail with a readable usage message
# instead of a bare IndexError when an argument is missing.
if len(sys.argv) < 3:
    raise SystemExit(f"Usage: {sys.argv[0]} <exp_dir> <data_csv>")
exp_dir = sys.argv[1]
data_csv = sys.argv[2]

# Per-threshold results are written next to the experiment's other artifacts.
output_csv = f"{exp_dir}/notes_thresh.csv"

# Prefer the best checkpoint; fall back to the last one when it is absent.
model_path = f"{exp_dir}/checkpoints/checkpoint_best.pt"
if not os.path.exists(model_path):
    print("Best model not found. Using last model.")
    model_path = f"{exp_dir}/checkpoints/checkpoint_last.pt"

# The configuration is read from the experiment's .hydra directory.
config_path = f"{exp_dir}/.hydra/config.yaml"
config = load_config(config_path)
print(config)
if not config["predict_notes"]:
    raise ValueError("This script works only for models that predict notes.")

# Build the network and restore its weights in inference mode.
# NOTE(review): torch.load unpickles the checkpoint file; only point this
# script at experiment directories you trust.
cnn = CNN(config.model)
print(cnn)
cnn.load_state_dict(torch.load(model_path, map_location=device)["state_dict"])
cnn.eval()
cnn.to(device)

# Converts audio files into the input features fed to the network.
audio_processor = AudioProcessor(config.audio)
# Start a fresh results file that contains only the column header
# (any previous contents are truncated by mode "w").
with open(output_csv, "w") as results_file:
    print("threshold,precision,recall,f_measure", file=results_file)
# Read every row of the semicolon-delimited listing into memory as dicts
# keyed by the header names. newline="" hands line-ending handling to the
# csv module, as its documentation requires, so quoted fields that contain
# newlines are parsed correctly.
with open(data_csv, "r", newline="") as f:
    data = list(csv.DictReader(f, delimiter=";"))
# Run the model on every segment listed in `data` and group the results by
# source recording. `instances` maps a recording name to a dict with:
#   'target_notes': ground-truth note matrix, segments joined along axis 1
#   'pred_notes':   sigmoid note probabilities, segments joined along axis 1

# The FFM kernel depends only on the config, so build it once rather than
# once per segment.
if config.insert_ffm:
    ffm_map_kernel = (
        torch.tensor(list(config.ffm.map_kernel))
        .unsqueeze(0)
        .unsqueeze(0)
        .unsqueeze(0)
        .float()
    ).to(device)
else:
    ffm_map_kernel = None

# Collect per-segment arrays in lists and concatenate once at the end;
# concatenating inside the loop re-copies everything gathered so far.
segments = {}
for row in data:
    audio_path = row["segment_path"]
    print(os.path.basename(audio_path))
    instance_name = os.path.basename(row["file_name"])

    # Targets are only consumed as NumPy arrays, so keep them on the CPU.
    target_notes = (
        torch.load(row["notes_path"], map_location="cpu").float().squeeze().numpy()
    )

    # Frets are only needed to build the FFM input, so skip loading them
    # when the model does not use it.
    if ffm_map_kernel is not None:
        frets = torch.load(row["frets_path"], map_location="cpu").float().to(device)
        frets = resize_target(
            frets,
            target_len=config.target_len_frames,
            upsample_method=config.data.target_len_frames_upsample_method,
        )
        ffm = get_ffm(frets, ffm_map_kernel=ffm_map_kernel).to(device).unsqueeze(0)
    else:
        ffm = None

    audio = audio_processor.load_wav(audio_path).to(device)
    feature = audio_processor.wav2feature(audio)
    feature = torch.tensor(feature).to(device)
    feature = feature.unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        output = cnn(feature, ffm=ffm)["notes"]

    # Fail loudly on an unexpected rank instead of leaving `pred_notes`
    # unbound (or silently reusing the previous segment's predictions).
    if output.dim() not in (3, 4):
        raise ValueError(
            f"Unexpected notes output shape {tuple(output.shape)}; "
            "expected a 3-D or 4-D tensor."
        )
    probs = torch.sigmoid(output).squeeze().cpu().numpy()
    # For 4-D outputs, index 1 of the squeezed result is kept
    # (original note: "not blank MCTC").
    pred_notes = probs if output.dim() == 3 else probs[1]

    entry = segments.setdefault(instance_name, {'target_notes': [], 'pred_notes': []})
    entry['target_notes'].append(target_notes)
    entry['pred_notes'].append(pred_notes)

# A recording with a single segment keeps that array untouched; otherwise
# its segments are joined along axis 1.
instances = {
    name: {
        key: chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=1)
        for key, chunks in parts.items()
    }
    for name, parts in segments.items()
}
# Sweep the binarisation threshold over 0.1 .. 0.9 and, for each value,
# average note-level precision / recall / F-measure over all recordings.
# One CSV row is appended per threshold so partial results survive a crash.

# Reference notes do not depend on the threshold, so convert each
# recording's target matrix to (intervals, pitches) once up front.
references = {
    name: notes_to_hz_mir_transcription(entry['target_notes'].astype(int))
    for name, entry in instances.items()
}

# Round the grid so thresholds are logged as 0.3 rather than
# 0.30000000000000004 (floating-point error in np.arange).
for thresh in np.round(np.arange(0.1, 1.0, 0.1), 1):
    print('-'*50)
    print(f"Threshold: {thresh}:")
    all_precision, all_recall, all_f_measure = [], [], []
    for instance_name, entry in instances.items():
        ref_intervals, ref_pitches = references[instance_name]
        pred_notes = (entry['pred_notes'] >= thresh).astype(int)
        est_intervals, est_pitches = notes_to_hz_mir_transcription(pred_notes)
        if est_intervals.size == 0:
            print(f"{instance_name}: Estimated notes are empty. Maybe threshold is too high.")
            # Score the recording as all zeros (mir_eval's own convention
            # for empty estimates) instead of skipping it. Skipping would
            # average each threshold over a different subset of recordings
            # and make high thresholds look better than they are.
            precision, recall, f_measure = 0.0, 0.0, 0.0
        else:
            # A single call is enough: the inputs are the same for the whole
            # recording, so repeating it per string only repeated the work.
            precision, recall, f_measure, _ = mir_eval.transcription.precision_recall_f1_overlap(
                ref_intervals, ref_pitches, est_intervals, est_pitches
            )
        all_precision.append(precision)
        all_recall.append(recall)
        all_f_measure.append(f_measure)

    mean_precision = np.mean(all_precision)
    mean_recall = np.mean(all_recall)
    mean_f_measure = np.mean(all_f_measure)
    print(f"Precision: {mean_precision}, Recall: {mean_recall}, F-measure: {mean_f_measure}")
    with open(output_csv, "a") as results_file:
        results_file.write(f"{thresh},{mean_precision},{mean_recall},{mean_f_measure}\n")