Skip to content

Commit 83bba97

Browse files
committed
added k_nearest_neighbors
1 parent ed7a42e commit 83bba97

File tree

3 files changed

+354
-0
lines changed

3 files changed

+354
-0
lines changed

DIRECTORY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
* [K Means](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_means.rs)
164164
* [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs)
165165
* [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs)
166+
* [K Nearest Neighbors](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_nearest_neighbors.rs)
166167
* Loss Function
167168
* [Average Margin Ranking Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/average_margin_ranking_loss.rs)
168169
* [Hinge Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/hinge_loss.rs)
Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
//! K-Nearest Neighbors (KNN) algorithm implementation
2+
//!
3+
//! KNN is a supervised machine learning algorithm used for classification and regression.
4+
//! It predicts the class/value of a data point based on the k nearest neighbors in the feature space.
5+
//!
6+
//! # Examples
7+
//!
8+
//! ```
9+
//! use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
10+
//!
11+
//! let mut knn = KNearestNeighbors::new(3);
12+
//!
13+
//! let training_data = vec![
14+
//! DataPoint::new(vec![1.0, 1.0], "A".to_string()),
15+
//! DataPoint::new(vec![2.0, 2.0], "A".to_string()),
16+
//! DataPoint::new(vec![5.0, 5.0], "B".to_string()),
17+
//! ];
18+
//!
19+
//! knn.fit(training_data);
20+
//!
21+
//! let prediction = knn.predict(&vec![1.5, 1.5]);
22+
//! assert_eq!(prediction, Some("A".to_string()));
23+
//! ```
24+
25+
use std::collections::HashMap;
26+
27+
/// A labeled training sample: a feature vector paired with its class label.
#[derive(Debug, Clone, PartialEq)]
pub struct DataPoint {
    /// Feature values describing this sample.
    pub features: Vec<f64>,
    /// Class label assigned to this sample.
    pub label: String,
}
33+
34+
impl DataPoint {
35+
/// Creates a new DataPoint
36+
///
37+
/// # Arguments
38+
///
39+
/// * `features` - Feature vector for the data point
40+
/// * `label` - Class label for the data point
41+
///
42+
/// # Examples
43+
///
44+
/// ```
45+
/// use the_algorithms_rust::machine_learning::DataPoint;
46+
///
47+
/// let point = DataPoint::new(vec![1.0, 2.0], "A".to_string());
48+
/// ```
49+
pub fn new(features: Vec<f64>, label: String) -> Self {
50+
DataPoint { features, label }
51+
}
52+
}
53+
54+
/// K-Nearest Neighbors classifier
55+
///
56+
/// # Examples
57+
///
58+
/// ```
59+
/// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
60+
///
61+
/// let mut knn = KNearestNeighbors::new(3);
62+
/// ```
63+
#[derive(Debug)]
64+
pub struct KNearestNeighbors {
65+
k: usize,
66+
training_data: Vec<DataPoint>,
67+
}
68+
69+
impl KNearestNeighbors {
70+
/// Creates a new KNN classifier with k neighbors
71+
///
72+
/// # Arguments
73+
///
74+
/// * `k` - Number of nearest neighbors to consider
75+
///
76+
/// # Panics
77+
///
78+
/// Panics if k is 0
79+
///
80+
/// # Examples
81+
///
82+
/// ```
83+
/// use the_algorithms_rust::machine_learning::KNearestNeighbors;
84+
///
85+
/// let knn = KNearestNeighbors::new(3);
86+
/// ```
87+
pub fn new(k: usize) -> Self {
88+
assert!(k > 0, "k must be greater than 0");
89+
KNearestNeighbors {
90+
k,
91+
training_data: Vec::new(),
92+
}
93+
}
94+
95+
/// Trains the KNN model with training data
96+
///
97+
/// # Arguments
98+
///
99+
/// * `training_data` - Vector of labeled data points
100+
///
101+
/// # Examples
102+
///
103+
/// ```
104+
/// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
105+
///
106+
/// let mut knn = KNearestNeighbors::new(3);
107+
/// let data = vec![DataPoint::new(vec![1.0, 2.0], "A".to_string())];
108+
/// knn.fit(data);
109+
/// ```
110+
pub fn fit(&mut self, training_data: Vec<DataPoint>) {
111+
self.training_data = training_data;
112+
}
113+
114+
/// Calculates Euclidean distance between two feature vectors
115+
///
116+
/// # Panics
117+
///
118+
/// Panics if feature vectors have different lengths
119+
fn euclidean_distance(&self, a: &[f64], b: &[f64]) -> f64 {
120+
assert_eq!(
121+
a.len(),
122+
b.len(),
123+
"Feature vectors must have the same length"
124+
);
125+
a.iter()
126+
.zip(b.iter())
127+
.map(|(x, y)| (x - y).powi(2))
128+
.sum::<f64>()
129+
.sqrt()
130+
}
131+
132+
/// Predicts the label for a given data point
133+
///
134+
/// Returns `None` if training data is empty
135+
///
136+
/// # Arguments
137+
///
138+
/// * `features` - Feature vector to classify
139+
///
140+
/// # Examples
141+
///
142+
/// ```
143+
/// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
144+
///
145+
/// let mut knn = KNearestNeighbors::new(1);
146+
/// knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
147+
/// let result = knn.predict(&vec![1.5, 1.5]);
148+
/// assert_eq!(result, Some("A".to_string()));
149+
/// ```
150+
pub fn predict(&self, features: &[f64]) -> Option<String> {
151+
if self.training_data.is_empty() {
152+
return None;
153+
}
154+
155+
// Calculate distances to all training points
156+
let mut distances: Vec<(f64, &DataPoint)> = self
157+
.training_data
158+
.iter()
159+
.map(|point| (self.euclidean_distance(features, &point.features), point))
160+
.collect();
161+
162+
// Sort by distance
163+
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
164+
165+
// Take k nearest neighbors
166+
let k_nearest = &distances[..self.k.min(distances.len())];
167+
168+
// Count votes for each label
169+
let mut votes: HashMap<String, usize> = HashMap::new();
170+
for (_, point) in k_nearest {
171+
*votes.entry(point.label.clone()).or_insert(0) += 1;
172+
}
173+
174+
// Return the label with the most votes
175+
votes
176+
.into_iter()
177+
.max_by_key(|(_, count)| *count)
178+
.map(|(label, _)| label)
179+
}
180+
181+
/// Predicts labels for multiple data points
182+
///
183+
/// # Arguments
184+
///
185+
/// * `features_batch` - Slice of feature vectors to classify
186+
///
187+
/// # Examples
188+
///
189+
/// ```
190+
/// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
191+
///
192+
/// let mut knn = KNearestNeighbors::new(1);
193+
/// knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
194+
/// let results = knn.predict_batch(&[vec![1.5, 1.5], vec![1.2, 1.2]]);
195+
/// ```
196+
pub fn predict_batch(&self, features_batch: &[Vec<f64>]) -> Vec<Option<String>> {
197+
features_batch
198+
.iter()
199+
.map(|features| self.predict(features))
200+
.collect()
201+
}
202+
203+
/// Calculates accuracy on test data
204+
///
205+
/// Returns accuracy as a value between 0.0 and 1.0
206+
///
207+
/// # Arguments
208+
///
209+
/// * `test_data` - Test data points with known labels
210+
///
211+
/// # Examples
212+
///
213+
/// ```
214+
/// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
215+
///
216+
/// let mut knn = KNearestNeighbors::new(1);
217+
/// knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
218+
/// let test_data = vec![DataPoint::new(vec![1.1, 1.1], "A".to_string())];
219+
/// let accuracy = knn.score(&test_data);
220+
/// assert!(accuracy > 0.0);
221+
/// ```
222+
pub fn score(&self, test_data: &[DataPoint]) -> f64 {
223+
if test_data.is_empty() {
224+
return 0.0;
225+
}
226+
227+
let correct = test_data
228+
.iter()
229+
.filter(|point| {
230+
if let Some(predicted) = self.predict(&point.features) {
231+
predicted == point.label
232+
} else {
233+
false
234+
}
235+
})
236+
.count();
237+
238+
correct as f64 / test_data.len() as f64
239+
}
240+
}
241+
242+
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_knn_simple_classification() {
        let mut knn = KNearestNeighbors::new(3);

        let training_data = vec![
            DataPoint::new(vec![1.0, 1.0], "A".to_string()),
            DataPoint::new(vec![1.5, 1.5], "A".to_string()),
            DataPoint::new(vec![2.0, 2.0], "A".to_string()),
            DataPoint::new(vec![5.0, 5.0], "B".to_string()),
            DataPoint::new(vec![5.5, 5.5], "B".to_string()),
            DataPoint::new(vec![6.0, 6.0], "B".to_string()),
        ];

        knn.fit(training_data);

        // Slices are enough here; `&vec![…]` would allocate needlessly
        // (clippy::useless_vec).
        assert_eq!(knn.predict(&[1.2, 1.2]).unwrap(), "A");
        assert_eq!(knn.predict(&[5.2, 5.2]).unwrap(), "B");
    }

    #[test]
    fn test_euclidean_distance() {
        let knn = KNearestNeighbors::new(1);
        // 3-4-5 right triangle.
        let distance = knn.euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((distance - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_knn_with_k_equals_one() {
        let mut knn = KNearestNeighbors::new(1);

        let training_data = vec![
            DataPoint::new(vec![1.0, 1.0], "A".to_string()),
            DataPoint::new(vec![10.0, 10.0], "B".to_string()),
        ];

        knn.fit(training_data);

        assert_eq!(knn.predict(&[1.5, 1.5]).unwrap(), "A");
        assert_eq!(knn.predict(&[9.5, 9.5]).unwrap(), "B");
    }

    #[test]
    fn test_knn_accuracy() {
        let mut knn = KNearestNeighbors::new(3);

        let training_data = vec![
            DataPoint::new(vec![1.0, 1.0], "A".to_string()),
            DataPoint::new(vec![1.5, 1.5], "A".to_string()),
            DataPoint::new(vec![2.0, 2.0], "A".to_string()),
            DataPoint::new(vec![5.0, 5.0], "B".to_string()),
            DataPoint::new(vec![5.5, 5.5], "B".to_string()),
            DataPoint::new(vec![6.0, 6.0], "B".to_string()),
        ];

        knn.fit(training_data);

        let test_data = vec![
            DataPoint::new(vec![1.2, 1.2], "A".to_string()),
            DataPoint::new(vec![5.2, 5.2], "B".to_string()),
        ];

        let accuracy = knn.score(&test_data);
        assert!((accuracy - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_predict_batch() {
        let mut knn = KNearestNeighbors::new(3);

        let training_data = vec![
            DataPoint::new(vec![1.0, 1.0], "A".to_string()),
            DataPoint::new(vec![2.0, 2.0], "A".to_string()),
            DataPoint::new(vec![5.0, 5.0], "B".to_string()),
            DataPoint::new(vec![6.0, 6.0], "B".to_string()),
        ];

        knn.fit(training_data);

        let features_batch = vec![vec![1.5, 1.5], vec![5.5, 5.5]];
        let predictions = knn.predict_batch(&features_batch);

        assert_eq!(predictions[0].as_ref().unwrap(), "A");
        assert_eq!(predictions[1].as_ref().unwrap(), "B");
    }

    #[test]
    #[should_panic(expected = "k must be greater than 0")]
    fn test_knn_zero_k() {
        KNearestNeighbors::new(0);
    }

    #[test]
    fn test_empty_training_data() {
        let knn = KNearestNeighbors::new(3);
        assert!(knn.predict(&[1.0, 1.0]).is_none());
    }

    #[test]
    #[should_panic(expected = "Feature vectors must have the same length")]
    fn test_mismatched_feature_lengths() {
        let knn = KNearestNeighbors::new(1);
        knn.euclidean_distance(&[1.0, 2.0], &[1.0]);
    }
}

src/machine_learning/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ mod linear_regression;
44
mod logistic_regression;
55
mod loss_function;
66
mod optimization;
7+
mod k_nearest_neighbors;
8+
79

810
pub use self::cholesky::cholesky;
911
pub use self::k_means::k_means;
@@ -18,3 +20,5 @@ pub use self::loss_function::mse_loss;
1820
pub use self::loss_function::neg_log_likelihood;
1921
pub use self::optimization::gradient_descent;
2022
pub use self::optimization::Adam;
23+
pub use self::k_nearest_neighbors::{DataPoint, KNearestNeighbors};
24+

0 commit comments

Comments
 (0)