-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_classifier.py
More file actions
55 lines (46 loc) · 1.96 KB
/
test_classifier.py
File metadata and controls
55 lines (46 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
from music_classifier import MusicClassifier
def _sample_audio_files(data_dir, extensions=('.mp3', '.wav')):
    """Yield the path of the first audio file found in each genre folder.

    Every subdirectory of *data_dir* is treated as a genre folder; plain files
    sitting directly in *data_dir* are ignored. At most one path is yielded per
    subdirectory, in the (arbitrary) order returned by ``os.listdir``.

    Args:
        data_dir: Directory containing one subdirectory per genre.
        extensions: Filename suffixes accepted as audio (case-sensitive).

    Raises:
        FileNotFoundError: If *data_dir* does not exist.
    """
    for genre in os.listdir(data_dir):
        genre_path = os.path.join(data_dir, genre)
        if not os.path.isdir(genre_path):
            continue
        for file_name in os.listdir(genre_path):
            if file_name.endswith(extensions):
                yield os.path.join(genre_path, file_name)
                # Only one sample per genre folder is needed.
                break


def _print_setup_instructions(data_dir):
    """Tell the user where to put training audio under *data_dir*."""
    print("Please add some .mp3 or .wav files to the following directories:")
    print(f"- {data_dir}/rock/")
    print(f"- {data_dir}/jazz/")
    print(f"- {data_dir}/classical/")


def main(data_dir="data"):
    """Train, save and smoke-test a Random Forest MusicClassifier.

    *data_dir* is expected to hold one subdirectory per genre containing
    ``.mp3`` or ``.wav`` files. If the directory is missing or holds no audio
    files, setup instructions are printed and the function returns without
    building or training a classifier.

    Args:
        data_dir: Root directory of the genre-labelled audio dataset. The
            default ``"data"`` matches the script's original hard-coded path.
    """
    # Validate the dataset first so a missing or empty directory fails fast
    # with instructions instead of a FileNotFoundError traceback, and without
    # paying for classifier construction.
    if not os.path.isdir(data_dir):
        print(f"\nData directory '{data_dir}' not found!")
        _print_setup_instructions(data_dir)
        return
    if next(_sample_audio_files(data_dir), None) is None:
        print("\nNo audio files found in the data directory!")
        _print_setup_instructions(data_dir)
        return

    # Initialize the classifier with Random Forest
    print("Initializing music classifier...")
    classifier = MusicClassifier(model_type='rf')

    # Prepare and train the model
    print("\nPreparing dataset...")
    X, y, _classes = classifier.prepare_dataset(data_dir)
    print("\nTraining model...")
    # train() returns a pair this script does not use; the unpacking is kept
    # so an unexpected return shape still fails loudly.
    _X_test, _y_test = classifier.train(X, y, tune_hyperparameters=True)

    # Save the model
    classifier.save_model()
    print("\nModel saved successfully!")

    # Test prediction on one sample file from each genre folder
    print("\nTesting prediction on a sample file...")
    for test_file in _sample_audio_files(data_dir):
        print(f"\nTesting file: {test_file}")
        predicted_genre, confidence = classifier.predict(test_file)
        print(f"Predicted genre: {predicted_genre}")
        print(f"Confidence: {confidence:.2f}")
# Run the train/save/predict pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()