Skip to content

Commit a486ee2

Browse files
committed
Add K-Nearest Neighbor (KNN) implementation and Test
1 parent e2c20ed commit a486ee2

File tree

3 files changed

+195
-0
lines changed

3 files changed

+195
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using NUnit.Framework;
2+
using Algorithms.MachineLearning;
3+
using System;
4+
5+
namespace Algorithms.Tests.MachineLearning;
6+
7+
[TestFixture]
public class KNearestNeighborsTests
{
    [Test]
    public void Constructor_InvalidK_ThrowsException()
    {
        // k must be at least 1; zero is rejected at construction time.
        TestDelegate construct = () => _ = new KNearestNeighbors<string>(0);

        Assert.Throws<ArgumentOutOfRangeException>(construct);
    }

    [Test]
    public void AddSample_NullFeatures_ThrowsException()
    {
        var sut = new KNearestNeighbors<string>(3);

        // A null feature vector is never a valid sample.
        Assert.Throws<ArgumentNullException>(() => sut.AddSample(null!, "A"));
    }

    [Test]
    public void Predict_NoTrainingData_ThrowsException()
    {
        // Prediction is meaningless before any sample has been added.
        var sut = new KNearestNeighbors<string>(1);

        Assert.Throws<InvalidOperationException>(() => sut.Predict(new[] { 1.0 }));
    }

    [Test]
    public void Predict_NullFeatures_ThrowsException()
    {
        var sut = new KNearestNeighbors<string>(1);
        sut.AddSample(new[] { 1.0 }, "A");

        Assert.Throws<ArgumentNullException>(() => sut.Predict(null!));
    }

    [Test]
    public void EuclideanDistance_DifferentLengths_ThrowsException()
    {
        double[] shorter = { 1.0 };
        double[] longer = { 1.0, 2.0 };

        Assert.Throws<ArgumentException>(() => KNearestNeighbors<string>.EuclideanDistance(shorter, longer));
    }

    [Test]
    public void EuclideanDistance_CorrectResult()
    {
        // 3-4-5 right triangle: sqrt(3^2 + 4^2) == 5.
        double[] first = { 1.0, 2.0 };
        double[] second = { 4.0, 6.0 };

        double distance = KNearestNeighbors<string>.EuclideanDistance(first, second);

        Assert.That(distance, Is.EqualTo(5.0).Within(1e-9));
    }

    [Test]
    public void Predict_SingleNeighbor_CorrectLabel()
    {
        // With k = 1 the query should adopt the label of the single closest sample.
        var sut = new KNearestNeighbors<string>(1);
        sut.AddSample(new double[] { 1.0, 2.0 }, "A");
        sut.AddSample(new double[] { 3.0, 4.0 }, "B");

        string predicted = sut.Predict(new[] { 1.1, 2.1 });

        Assert.That(predicted, Is.EqualTo("A"));
    }

    [Test]
    public void Predict_MajorityVote_CorrectLabel()
    {
        // Two "A" samples outvote the lone "B" among the 3 nearest neighbors.
        var sut = new KNearestNeighbors<string>(3);
        sut.AddSample(new double[] { 0.0, 0.0 }, "A");
        sut.AddSample(new double[] { 0.1, 0.1 }, "A");
        sut.AddSample(new double[] { 1.0, 1.0 }, "B");

        string predicted = sut.Predict(new[] { 0.05, 0.05 });

        Assert.That(predicted, Is.EqualTo("A"));
    }

    [Test]
    public void Predict_TieBreaker_ReturnsConsistentLabel()
    {
        // Query point is equidistant from both samples; the tie-break must be stable.
        var sut = new KNearestNeighbors<string>(2);
        sut.AddSample(new double[] { 0.0, 0.0 }, "A");
        sut.AddSample(new double[] { 1.0, 1.0 }, "B");

        string predicted = sut.Predict(new[] { 0.5, 0.5 });

        Assert.That(predicted, Is.EqualTo("A"));
    }
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
5+
namespace Algorithms.MachineLearning;
6+
7+
/// <summary>
/// K-Nearest Neighbors (KNN) classifier implementation.
/// This algorithm classifies data points based on the majority label of their k nearest neighbors.
/// Ties in the majority vote are broken deterministically: the label whose nearest member is
/// closest to the query wins, and remaining ties fall back to sample insertion order.
/// </summary>
/// <typeparam name="TLabel">
/// The type of the label used for classification. This can be any type that represents the class or category of a sample.
/// </typeparam>
public class KNearestNeighbors<TLabel>
{
    private readonly List<(double[] Features, TLabel Label)> trainingData = new();
    private readonly int k;

    /// <summary>
    /// Initializes a new instance of the <see cref="KNearestNeighbors{TLabel}"/> classifier.
    /// </summary>
    /// <param name="k">Number of neighbors to consider for classification.</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if k is less than 1.</exception>
    public KNearestNeighbors(int k)
    {
        if (k < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(k), "k must be at least 1.");
        }

        this.k = k;
    }

    /// <summary>
    /// Calculates the Euclidean distance between two feature vectors.
    /// </summary>
    /// <param name="a">First feature vector.</param>
    /// <param name="b">Second feature vector.</param>
    /// <returns>Euclidean distance.</returns>
    /// <exception cref="ArgumentNullException">Thrown if either vector is null.</exception>
    /// <exception cref="ArgumentException">Thrown if vectors are of different lengths.</exception>
    public static double EuclideanDistance(double[] a, double[] b)
    {
        ArgumentNullException.ThrowIfNull(a);
        ArgumentNullException.ThrowIfNull(b);

        if (a.Length != b.Length)
        {
            throw new ArgumentException("Feature vectors must be of the same length.");
        }

        double sum = 0;
        for (int i = 0; i < a.Length; i++)
        {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }

        return Math.Sqrt(sum);
    }

    /// <summary>
    /// Adds a training sample to the classifier.
    /// </summary>
    /// <param name="features">Feature vector of the sample.</param>
    /// <param name="label">Label of the sample.</param>
    /// <exception cref="ArgumentNullException">Thrown if features is null.</exception>
    public void AddSample(double[] features, TLabel label)
    {
        ArgumentNullException.ThrowIfNull(features);

        trainingData.Add((features, label));
    }

    /// <summary>
    /// Predicts the label for a given feature vector using the KNN algorithm.
    /// </summary>
    /// <param name="features">Feature vector to classify.</param>
    /// <returns>Predicted label.</returns>
    /// <exception cref="ArgumentNullException">Thrown if features is null.</exception>
    /// <exception cref="InvalidOperationException">Thrown if there is no training data.</exception>
    public TLabel Predict(double[] features)
    {
        // Validate the argument before checking classifier state.
        ArgumentNullException.ThrowIfNull(features);

        if (trainingData.Count == 0)
        {
            throw new InvalidOperationException("No training data available.");
        }

        // Compute distances to all training samples and keep the k nearest.
        // Enumerable.OrderBy is a stable sort, so equidistant samples retain
        // their insertion order.
        var nearest = trainingData
            .Select(td => (td.Label, Distance: EuclideanDistance(features, td.Features)))
            .OrderBy(x => x.Distance)
            .Take(k)
            .ToList();

        // Majority vote. Ties on the vote count are broken by the smallest
        // distance within each label group; any remaining tie resolves to the
        // group that appears first (GroupBy and OrderByDescending are stable),
        // i.e. sample insertion order. This is deterministic across runs,
        // unlike hash-code ordering (string hash codes are randomized per
        // process on .NET Core).
        return nearest
            .GroupBy(x => x.Label)
            .OrderByDescending(g => g.Count())
            .ThenBy(g => g.Min(x => x.Distance))
            .First()
            .Key;
    }
}

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ find more than one implementation for the same objective but using different alg
108108
* [CollaborativeFiltering](./Algorithms/RecommenderSystem/CollaborativeFiltering)
109109
* [Machine Learning](./Algorithms/MachineLearning)
110110
* [Linear Regression](./Algorithms/MachineLearning/LinearRegression.cs)
111+
* [K-Nearest Neighbors](./Algorithms/MachineLearning/KNearestNeighbors.cs)
111112
* [Searches](./Algorithms/Search)
112113
* [A-Star](./Algorithms/Search/AStar/)
113114
* [Binary Search](./Algorithms/Search/BinarySearcher.cs)

0 commit comments

Comments
 (0)