Skip to content

Commit 0eb0bcd

Browse files
committed
Uploaded code for LSTM music genre classifier.
1 parent 5531967 commit 0eb0bcd

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import json
2+
import numpy as np
3+
from sklearn.model_selection import train_test_split
4+
import tensorflow.keras as keras
5+
import matplotlib.pyplot as plt
6+
7+
DATA_PATH = "../13/data_10.json"
8+
9+
10+
def load_data(data_path):
11+
"""Loads training dataset from json file.
12+
13+
:param data_path (str): Path to json file containing data
14+
:return X (ndarray): Inputs
15+
:return y (ndarray): Targets
16+
"""
17+
18+
with open(data_path, "r") as fp:
19+
data = json.load(fp)
20+
21+
X = np.array(data["mfcc"])
22+
y = np.array(data["labels"])
23+
return X, y
24+
25+
26+
def plot_history(history):
27+
"""Plots accuracy/loss for training/validation set as a function of the epochs
28+
29+
:param history: Training history of model
30+
:return:
31+
"""
32+
33+
fig, axs = plt.subplots(2)
34+
35+
# create accuracy sublpot
36+
axs[0].plot(history.history["accuracy"], label="train accuracy")
37+
axs[0].plot(history.history["val_accuracy"], label="test accuracy")
38+
axs[0].set_ylabel("Accuracy")
39+
axs[0].legend(loc="lower right")
40+
axs[0].set_title("Accuracy eval")
41+
42+
# create error sublpot
43+
axs[1].plot(history.history["loss"], label="train error")
44+
axs[1].plot(history.history["val_loss"], label="test error")
45+
axs[1].set_ylabel("Error")
46+
axs[1].set_xlabel("Epoch")
47+
axs[1].legend(loc="upper right")
48+
axs[1].set_title("Error eval")
49+
50+
plt.show()
51+
52+
53+
def prepare_datasets(test_size, validation_size):
54+
"""Loads data and splits it into train, validation and test sets.
55+
56+
:param test_size (float): Value in [0, 1] indicating percentage of data set to allocate to test split
57+
:param validation_size (float): Value in [0, 1] indicating percentage of train set to allocate to validation split
58+
59+
:return X_train (ndarray): Input training set
60+
:return X_validation (ndarray): Input validation set
61+
:return X_test (ndarray): Input test set
62+
:return y_train (ndarray): Target training set
63+
:return y_validation (ndarray): Target validation set
64+
:return y_test (ndarray): Target test set
65+
"""
66+
67+
# load data
68+
X, y = load_data(DATA_PATH)
69+
70+
# create train, validation and test split
71+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)
72+
X_train, X_validation, y_train, y_validation = train_test_split(X_train, y_train, test_size=validation_size)
73+
74+
return X_train, X_validation, X_test, y_train, y_validation, y_test
75+
76+
77+
def build_model(input_shape):
78+
"""Generates RNN-LSTM model
79+
80+
:param input_shape (tuple): Shape of input set
81+
:return model: RNN-LSTM model
82+
"""
83+
84+
# build network topology
85+
model = keras.Sequential()
86+
87+
# 2 LSTM layers
88+
model.add(keras.layers.LSTM(64, input_shape=input_shape, return_sequences=True))
89+
model.add(keras.layers.LSTM(64))
90+
91+
# dense layer
92+
model.add(keras.layers.Dense(64, activation='relu'))
93+
model.add(keras.layers.Dropout(0.3))
94+
95+
# output layer
96+
model.add(keras.layers.Dense(10, activation='softmax'))
97+
98+
return model
99+
100+
101+
if __name__ == "__main__":
102+
103+
# get train, validation, test splits
104+
X_train, X_validation, X_test, y_train, y_validation, y_test = prepare_datasets(0.25, 0.2)
105+
106+
# create network
107+
input_shape = (X_train.shape[1], X_train.shape[2]) # 130, 13
108+
model = build_model(input_shape)
109+
110+
# compile model
111+
optimiser = keras.optimizers.Adam(learning_rate=0.0001)
112+
model.compile(optimizer=optimiser,
113+
loss='sparse_categorical_crossentropy',
114+
metrics=['accuracy'])
115+
116+
model.summary()
117+
118+
# train model
119+
history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=32, epochs=30)
120+
121+
# plot accuracy/error for training and validation
122+
plot_history(history)
123+
124+
# evaluate model on test set
125+
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=2)
126+
print('\nTest accuracy:', test_acc)

0 commit comments

Comments
 (0)