|
"""
K-Nearest Neighbors (KNN) - Simple Implementation from Scratch
--------------------------------------------------------------
This script implements a basic version of the KNN algorithm for classification
using only Python, NumPy and Matplotlib (no sklearn).

Concept Summary:
----------------
1. KNN is a supervised learning algorithm used for classification & regression.
2. It finds the K nearest data points to a test point using a distance metric
   (usually Euclidean distance).
3. For classification -> predicts the majority label among neighbors.
4. For regression -> predicts the average value among neighbors.
5. It is a **lazy learner** (no explicit training phase; prediction happens at query time).

Steps in this code:
-------------------
1. Compute the Euclidean distance between the test point and all training points.
2. Sort training points by their distance to the test point.
3. Select the top 'k' nearest points.
4. Use majority voting to determine the predicted class.
5. Return the predicted label.

Contributor:
------------
Contributed by: Lakhinana Chaturvedi Kashyap
"""
| 28 | + |
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
| 32 | + |
| 33 | +# Function: Euclidean Distance |
| 34 | + |
def euclidean_distance(p1, p2):
    """
    Return the straight-line (Euclidean / L2) distance between two points.

    Accepts any same-length coordinate sequences; equivalent to
    sqrt(sum((p2_i - p1_i)^2)) over all coordinates.
    """
    difference = np.asarray(p1) - np.asarray(p2)
    # dot(d, d) is the sum of squared coordinate differences
    return np.sqrt(np.dot(difference, difference))
| 43 | + |
| 44 | + |
| 45 | +# Function: KNN Prediction |
| 46 | + |
def knn_prediction(training_data, training_labels, test_point, k):
    """
    Predict the class of a test point using the K-Nearest Neighbors algorithm.

    Parameters
    ----------
    training_data : array-like of shape (n_samples, n_features)
        Training feature vectors.
    training_labels : sequence of length n_samples
        Class label for each training point.
    test_point : array-like of shape (n_features,)
        Point to classify.
    k : int
        Number of neighbors that vote (must be >= 1).

    Returns
    -------
    tuple
        (predicted label, list of the k nearest training points) —
        the points are returned for visualization.

    Raises
    ------
    ValueError
        If k is less than 1.
    """
    if k < 1:
        # The original code failed later with an opaque IndexError for k < 1.
        raise ValueError("k must be a positive integer")

    data = np.asarray(training_data)
    point = np.asarray(test_point)

    # Vectorized Euclidean distances to every training point: one C-level
    # pass instead of a Python loop with one call per row.
    distances = np.linalg.norm(data - point, axis=1)

    # Stable sort preserves the original first-come tie-breaking order
    # among equidistant training points.
    order = np.argsort(distances, kind="stable")[:k]

    k_labels = [training_labels[i] for i in order]
    nearest_points = [data[i] for i in order]

    # Majority voting; Counter.most_common breaks ties by first insertion,
    # i.e. in favor of the closer of the tied labels.
    prediction = Counter(k_labels).most_common(1)[0][0]

    return prediction, nearest_points
| 68 | + |
# Example Usage

# Training set: two well-separated 2-D clusters, three points per class.
training_data = np.array(
    [[1.0, 2.0],
     [2.0, 3.0],
     [3.0, 1.0],
     [6.0, 5.0],
     [7.0, 7.0],
     [8.0, 6.0]]
)
training_labels = np.array([0, 0, 0, 1, 1, 1])

# Query point and neighborhood size.
test_point = np.array([5.0, 5.0])
k = 3

# Classify the query point.
prediction, nearest_points = knn_prediction(training_data, training_labels, test_point, k)

print("Predicted label:", prediction)
print("Nearest neighbors:", nearest_points)
| 91 | + |
# Visualization

plt.figure(figsize=(8, 6))

# Training points, one scatter call per class (blue = 0, red = 1).
for class_value, class_color, class_label in (
    (0, 'blue', 'Class 0'),
    (1, 'red', 'Class 1'),
):
    mask = training_labels == class_value
    plt.scatter(
        training_data[mask][:, 0],
        training_data[mask][:, 1],
        color=class_color, label=class_label, s=100
    )

# Ring the k nearest neighbors in yellow so the vote is visible.
nearest_points = np.array(nearest_points)
plt.scatter(
    nearest_points[:, 0],
    nearest_points[:, 1],
    edgecolor='black', facecolor='yellow', s=200, label=f'{k} Nearest Neighbors'
)

# The query point itself, drawn as a green star.
plt.scatter(
    test_point[0], test_point[1],
    color='green', marker='*', s=250, label='Test Point (Predicted)'
)

# Title, axis labels, legend, grid.
plt.title(f"KNN Visualization (k={k}) — Predicted Label: {prediction}")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.grid(True)
plt.show()