-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathKNN.py
More file actions
69 lines (57 loc) · 2.16 KB
/
KNN.py
File metadata and controls
69 lines (57 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.neighbors import KNeighborsClassifier
# Path of the JSON dataset passed to load_data(); that loader reads the
# "pitch" and "labels" entries from it.
DATA_PATH = "data.json"
def load_data(data_path):
    """Loads training dataset from json file.

    The file must contain a JSON object with a "pitch" entry (used as the
    inputs) and a "labels" entry (used as the targets). Both entries are
    converted to NumPy arrays.

    :param data_path (str): Path to json file containing data
    :return X (ndarray): Inputs, built from the "pitch" entry
    :return y (ndarray): Targets, built from the "labels" entry
    :raises KeyError: If the "pitch" or "labels" entry is missing
    """
    # JSON text is UTF-8; pass the encoding explicitly so decoding does not
    # depend on the platform's default locale encoding.
    with open(data_path, "r", encoding="utf-8") as fp:
        data = json.load(fp)

    X = np.array(data["pitch"])
    y = np.array(data["labels"])
    return X, y
# Load the dataset and hold out 25% of the samples for testing.
# Only a train/test split is made here; there is no separate validation set.
X, y = load_data(DATA_PATH)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)

# Neighbour counts to compare. Every classifier is trained and evaluated on
# the same split so the results are directly comparable.
NEIGHBOR_COUNTS = (1, 3, 5, 7)

# Fit one k-NN classifier per neighbour count and predict the test labels.
classifiers = {}
predictions = {}
for k in NEIGHBOR_COUNTS:
    classifiers[k] = KNeighborsClassifier(n_neighbors=k)
    classifiers[k].fit(X_train, y_train)
    predictions[k] = classifiers[k].predict(X_test)

print("test size: ", len(y_test))
print("train size: ",len(y_train))

# Report, for each neighbour count, how many test samples were misclassified
# and the resulting accuracy.
error_counts = {}
print("~~~Errors Comparison: ~~~")
for k in NEIGHBOR_COUNTS:
    y_pred = predictions[k]
    # Count mismatching labels in one vectorised pass instead of an index loop.
    error_counts[k] = int(np.count_nonzero(y_test != y_pred))
    print("Number of neighbor= ", k, " ,number of errors: ", error_counts[k],
          ", accuracy: ", metrics.accuracy_score(y_test, y_pred))