-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathValidation.py
More file actions
109 lines (86 loc) · 4.95 KB
/
Validation.py
File metadata and controls
109 lines (86 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import tensorflow as tf
import numpy as np
import os
from Input import Input as Input
import Models.UnetAudioSeparator
import Estimate_Sources
import Utils
import functools
from tensorflow.contrib.signal.python.ops import window_ops
def test(model_config, audio_list, model_folder, load_model):
    """Evaluate a pretrained source-separation model on a list of multitrack samples.

    Builds the separator graph, restores weights from ``load_model``, predicts
    sources for every mixture in ``audio_list`` and accumulates the error
    between predicted and ground-truth sources. Depending on the configuration
    the per-entry error is either a raw-audio squared error or a
    magnitude-spectrogram absolute difference.

    :param model_config: Dict of model hyperparameters/settings (network type,
        layer counts, sample rate, loss type, log dir, ...).
    :param audio_list: Iterable of samples; ``sample[0]`` is the mixture track
        and ``sample[1:]`` are the ground-truth source tracks. Each entry is
        expected to expose a ``.path`` attribute (see Utils.load usage below).
    :param model_folder: Subfolder name under ``model_config["log_dir"]`` where
        the TensorBoard summary is written.
    :param load_model: Checkpoint path handed to ``tf.train.Saver.restore``.
    :return: Mean per-entry loss over all tracks and sources.
    :raises NotImplementedError: If ``model_config["network"]`` is not "unet".
    """
    # Determine input and output shapes
    disc_input_shape = [model_config["batch_size"], model_config["num_frames"], 0]  # Shape of discriminator input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(model_config["num_layers"], model_config["num_initial_filters"],
                                                                       output_type=model_config["output_type"],
                                                                       context=model_config["context"],
                                                                       mono=model_config["mono_downmix"],
                                                                       upsampling=model_config["upsampling"],
                                                                       num_sources=model_config["num_sources"],
                                                                       filter_size=model_config["filter_size"],
                                                                       merge_filter_size=model_config["merge_filter_size"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Input context must be symmetric around the output window
    assert ((sep_input_shape[1] - sep_output_shape[1]) % 2 == 0)
    # Test with batch size of 1
    sep_input_shape[0] = 1
    sep_output_shape[0] = 1

    mix_context, sources = Input.get_multitrack_placeholders(sep_output_shape, model_config["num_sources"], sep_input_shape, "input")
    print("Testing...")

    # BUILD MODELS
    # Separator in inference mode (training flags off, no variable reuse)
    separator_sources = separator_func(mix_context, False, False, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False, dtype=tf.int64)

    # Start session
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep + model_folder, graph=sess.graph)

    # CHECKPOINTING
    # Load pretrained model to test
    restorer = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
    print("Num of variables" + str(len(tf.global_variables())))
    restorer.restore(sess, load_model)
    print('Pre-trained model restored for testing')

    # STFT graph used only by the spectrogram-loss variant below
    input_audio = tf.placeholder(tf.float32, shape=[None, 1])
    window = functools.partial(window_ops.hann_window, periodic=True)
    stft = tf.contrib.signal.stft(tf.squeeze(input_audio, 1), frame_length=1024, frame_step=768,
                                  fft_length=1024, window_fn=window)
    mag = tf.abs(stft)

    _global_step = sess.run(global_step)
    print("Starting!")
    total_loss = 0.0
    total_samples = 0
    for sample in audio_list:  # Go through all tracks
        # Load mixture and fetch prediction for mixture
        mix_audio, mix_sr = Utils.load(sample[0].path, sr=None, mono=False)
        sources_pred = Estimate_Sources.predict_track(model_config, sess, mix_audio, mix_sr, sep_input_shape, sep_output_shape, separator_sources, mix_context)

        # Load original (ground-truth) sources
        sources_gt = list()
        for s in sample[1:]:
            s_audio, _ = Utils.load(s.path, sr=model_config["expected_sr"], mono=model_config["mono_downmix"])
            sources_gt.append(s_audio)

        # Accumulate the error for THIS track separately. BUGFIX: previously
        # the per-track print used the running totals over all tracks seen so
        # far, so it reported a running average instead of the track's own MSE.
        # The final mean over all tracks is unchanged.
        track_loss = 0.0
        track_samples = 0
        for (source_gt, source_pred) in zip(sources_gt, sources_pred):
            if model_config["network"] == "unet_spectrogram" and not model_config["raw_audio_loss"]:
                # Magnitude-spectrogram absolute difference
                real_mag = sess.run(mag, feed_dict={input_audio : source_gt})
                pred_mag = sess.run(mag, feed_dict={input_audio: source_pred})
                track_loss += np.sum(np.abs(real_mag - pred_mag))
                track_samples += np.prod(real_mag.shape)  # Number of entries is product of number of sources and number of outputs per source
            else:
                # Raw-audio squared error
                track_loss += np.sum(np.square(source_gt - source_pred))
                track_samples += np.prod(source_gt.shape)  # Number of entries is product of number of sources and number of outputs per source
        total_loss += track_loss
        total_samples += track_samples
        print("MSE for track " + sample[0].path + ": " + str(track_loss / float(track_samples)))

    mean_mse_loss = total_loss / float(total_samples)

    # Write aggregate test loss as a TensorBoard summary at the restored step
    summary = tf.Summary(value=[tf.Summary.Value(tag="test_loss", simple_value=mean_mse_loss)])
    writer.add_summary(summary, global_step=_global_step)
    writer.flush()
    writer.close()
    print("Finished testing - Mean MSE: " + str(mean_mse_loss))

    # Close session, clear computational graph
    sess.close()
    tf.reset_default_graph()
    return mean_mse_loss