Skip to content

Commit 78daeba

Browse files
committed with message: "fixed"
1 parent 83bba97 commit 78daeba

File tree

2 files changed

+58
-59
lines changed

2 files changed

+58
-59
lines changed

src/machine_learning/k_nearest_neighbors.rs

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//! K-Nearest Neighbors (KNN) algorithm implementation
2-
//!
2+
//!
33
//! KNN is a supervised machine learning algorithm used for classification and regression.
44
//! It predicts the class/value of a data point based on the k nearest neighbors in the feature space.
55
//!
@@ -17,20 +17,17 @@
1717
//! ];
1818
//!
1919
//! knn.fit(training_data);
20-
//!
21-
//! let prediction = knn.predict(&vec![1.5, 1.5]);
20+
//!
21+
//! let prediction = knn.predict(&[1.5, 1.5]);
2222
//! assert_eq!(prediction, Some("A".to_string()));
2323
//! ```
24-
2524
use std::collections::HashMap;
26-
2725
/// Represents a data point with features and a label
28-
#[derive(Debug, Clone, PartialEq)] // Added PartialEq for better testing
26+
#[derive(Debug, Clone, PartialEq)]
2927
pub struct DataPoint {
3028
pub features: Vec<f64>,
3129
pub label: String,
3230
}
33-
3431
impl DataPoint {
3532
/// Creates a new DataPoint
3633
///
@@ -50,22 +47,20 @@ impl DataPoint {
5047
DataPoint { features, label }
5148
}
5249
}
53-
5450
/// K-Nearest Neighbors classifier
5551
///
5652
/// # Examples
5753
///
5854
/// ```
59-
/// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
55+
/// use the_algorithms_rust::machine_learning::KNearestNeighbors;
6056
///
61-
/// let mut knn = KNearestNeighbors::new(3);
57+
/// let knn = KNearestNeighbors::new(3);
6258
/// ```
6359
#[derive(Debug)]
6460
pub struct KNearestNeighbors {
6561
k: usize,
6662
training_data: Vec<DataPoint>,
6763
}
68-
6964
impl KNearestNeighbors {
7065
/// Creates a new KNN classifier with k neighbors
7166
///
@@ -91,7 +86,6 @@ impl KNearestNeighbors {
9186
training_data: Vec::new(),
9287
}
9388
}
94-
9589
/// Trains the KNN model with training data
9690
///
9791
/// # Arguments
@@ -110,7 +104,6 @@ impl KNearestNeighbors {
110104
pub fn fit(&mut self, training_data: Vec<DataPoint>) {
111105
self.training_data = training_data;
112106
}
113-
114107
/// Calculates Euclidean distance between two feature vectors
115108
///
116109
/// # Panics
@@ -128,7 +121,6 @@ impl KNearestNeighbors {
128121
.sum::<f64>()
129122
.sqrt()
130123
}
131-
132124
/// Predicts the label for a given data point
133125
///
134126
/// Returns `None` if training data is empty
@@ -144,40 +136,34 @@ impl KNearestNeighbors {
144136
///
145137
/// let mut knn = KNearestNeighbors::new(1);
146138
/// knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
147-
/// let result = knn.predict(&vec![1.5, 1.5]);
139+
/// let result = knn.predict(&[1.5, 1.5]);
148140
/// assert_eq!(result, Some("A".to_string()));
149141
/// ```
150142
pub fn predict(&self, features: &[f64]) -> Option<String> {
151143
if self.training_data.is_empty() {
152144
return None;
153145
}
154-
155146
// Calculate distances to all training points
156147
let mut distances: Vec<(f64, &DataPoint)> = self
157148
.training_data
158149
.iter()
159150
.map(|point| (self.euclidean_distance(features, &point.features), point))
160151
.collect();
161-
162152
// Sort by distance
163153
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
164-
165154
// Take k nearest neighbors
166155
let k_nearest = &distances[..self.k.min(distances.len())];
167-
168156
// Count votes for each label
169157
let mut votes: HashMap<String, usize> = HashMap::new();
170158
for (_, point) in k_nearest {
171159
*votes.entry(point.label.clone()).or_insert(0) += 1;
172160
}
173-
174161
// Return the label with the most votes
175162
votes
176163
.into_iter()
177164
.max_by_key(|(_, count)| *count)
178165
.map(|(label, _)| label)
179166
}
180-
181167
/// Predicts labels for multiple data points
182168
///
183169
/// # Arguments
@@ -199,7 +185,6 @@ impl KNearestNeighbors {
199185
.map(|features| self.predict(features))
200186
.collect()
201187
}
202-
203188
/// Calculates accuracy on test data
204189
///
205190
/// Returns accuracy as a value between 0.0 and 1.0
@@ -223,7 +208,6 @@ impl KNearestNeighbors {
223208
if test_data.is_empty() {
224209
return 0.0;
225210
}
226-
227211
let correct = test_data
228212
.iter()
229213
.filter(|point| {
@@ -234,19 +218,15 @@ impl KNearestNeighbors {
234218
}
235219
})
236220
.count();
237-
238221
correct as f64 / test_data.len() as f64
239222
}
240223
}
241-
242224
#[cfg(test)]
243225
mod tests {
244226
use super::*;
245-
246227
#[test]
247228
fn test_knn_simple_classification() {
248229
let mut knn = KNearestNeighbors::new(3);
249-
250230
let training_data = vec![
251231
DataPoint::new(vec![1.0, 1.0], "A".to_string()),
252232
DataPoint::new(vec![1.5, 1.5], "A".to_string()),
@@ -255,39 +235,30 @@ mod tests {
255235
DataPoint::new(vec![5.5, 5.5], "B".to_string()),
256236
DataPoint::new(vec![6.0, 6.0], "B".to_string()),
257237
];
258-
259238
knn.fit(training_data);
260-
261-
assert_eq!(knn.predict(&vec![1.2, 1.2]).unwrap(), "A");
262-
assert_eq!(knn.predict(&vec![5.2, 5.2]).unwrap(), "B");
239+
assert_eq!(knn.predict(&[1.2, 1.2]).unwrap(), "A");
240+
assert_eq!(knn.predict(&[5.2, 5.2]).unwrap(), "B");
263241
}
264-
265242
#[test]
266243
fn test_euclidean_distance() {
267244
let knn = KNearestNeighbors::new(1);
268-
let distance = knn.euclidean_distance(&vec![0.0, 0.0], &vec![3.0, 4.0]);
245+
let distance = knn.euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]);
269246
assert!((distance - 5.0).abs() < f64::EPSILON);
270247
}
271-
272248
#[test]
273249
fn test_knn_with_k_equals_one() {
274250
let mut knn = KNearestNeighbors::new(1);
275-
276251
let training_data = vec![
277252
DataPoint::new(vec![1.0, 1.0], "A".to_string()),
278253
DataPoint::new(vec![10.0, 10.0], "B".to_string()),
279254
];
280-
281255
knn.fit(training_data);
282-
283-
assert_eq!(knn.predict(&vec![1.5, 1.5]).unwrap(), "A");
284-
assert_eq!(knn.predict(&vec![9.5, 9.5]).unwrap(), "B");
256+
assert_eq!(knn.predict(&[1.5, 1.5]).unwrap(), "A");
257+
assert_eq!(knn.predict(&[9.5, 9.5]).unwrap(), "B");
285258
}
286-
287259
#[test]
288260
fn test_knn_accuracy() {
289261
let mut knn = KNearestNeighbors::new(3);
290-
291262
let training_data = vec![
292263
DataPoint::new(vec![1.0, 1.0], "A".to_string()),
293264
DataPoint::new(vec![1.5, 1.5], "A".to_string()),
@@ -296,54 +267,84 @@ mod tests {
296267
DataPoint::new(vec![5.5, 5.5], "B".to_string()),
297268
DataPoint::new(vec![6.0, 6.0], "B".to_string()),
298269
];
299-
300270
knn.fit(training_data);
301-
302271
let test_data = vec![
303272
DataPoint::new(vec![1.2, 1.2], "A".to_string()),
304273
DataPoint::new(vec![5.2, 5.2], "B".to_string()),
305274
];
306-
307275
let accuracy = knn.score(&test_data);
308276
assert!((accuracy - 1.0).abs() < f64::EPSILON);
309277
}
310-
311278
#[test]
312279
fn test_predict_batch() {
313280
let mut knn = KNearestNeighbors::new(3);
314-
315281
let training_data = vec![
316282
DataPoint::new(vec![1.0, 1.0], "A".to_string()),
317283
DataPoint::new(vec![2.0, 2.0], "A".to_string()),
318284
DataPoint::new(vec![5.0, 5.0], "B".to_string()),
319285
DataPoint::new(vec![6.0, 6.0], "B".to_string()),
320286
];
321-
322287
knn.fit(training_data);
323-
324288
let features_batch = vec![vec![1.5, 1.5], vec![5.5, 5.5]];
325289
let predictions = knn.predict_batch(&features_batch);
326-
327290
assert_eq!(predictions[0].as_ref().unwrap(), "A");
328291
assert_eq!(predictions[1].as_ref().unwrap(), "B");
329292
}
330-
331293
#[test]
332294
#[should_panic(expected = "k must be greater than 0")]
333295
fn test_knn_zero_k() {
334296
KNearestNeighbors::new(0);
335297
}
336-
337298
#[test]
338299
fn test_empty_training_data() {
339300
let knn = KNearestNeighbors::new(3);
340-
assert!(knn.predict(&vec![1.0, 1.0]).is_none());
301+
assert!(knn.predict(&[1.0, 1.0]).is_none());
341302
}
342-
343303
#[test]
344304
#[should_panic(expected = "Feature vectors must have the same length")]
345305
fn test_mismatched_feature_lengths() {
346306
let knn = KNearestNeighbors::new(1);
347-
knn.euclidean_distance(&vec![1.0, 2.0], &vec![1.0]);
307+
knn.euclidean_distance(&[1.0, 2.0], &[1.0]);
308+
}
309+
#[test]
310+
fn test_predict_batch_with_empty_training() {
311+
let knn = KNearestNeighbors::new(3);
312+
let features_batch = vec![vec![1.5, 1.5], vec![5.5, 5.5]];
313+
let predictions = knn.predict_batch(&features_batch);
314+
assert!(predictions[0].is_none());
315+
assert!(predictions[1].is_none());
316+
}
317+
#[test]
318+
fn test_score_with_empty_test_data() {
319+
let mut knn = KNearestNeighbors::new(3);
320+
knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
321+
let accuracy = knn.score(&[]);
322+
assert_eq!(accuracy, 0.0);
323+
}
324+
#[test]
325+
fn test_k_larger_than_training_data() {
326+
let mut knn = KNearestNeighbors::new(10);
327+
let training_data = vec![
328+
DataPoint::new(vec![1.0, 1.0], "A".to_string()),
329+
DataPoint::new(vec![2.0, 2.0], "A".to_string()),
330+
DataPoint::new(vec![5.0, 5.0], "B".to_string()),
331+
];
332+
knn.fit(training_data);
333+
// Should still work even when k > training_data.len()
334+
assert_eq!(knn.predict(&[1.5, 1.5]).unwrap(), "A");
335+
}
336+
#[test]
337+
fn test_tie_breaking() {
338+
let mut knn = KNearestNeighbors::new(2);
339+
let training_data = vec![
340+
DataPoint::new(vec![1.0, 1.0], "A".to_string()),
341+
DataPoint::new(vec![1.0, 1.0], "B".to_string()),
342+
];
343+
knn.fit(training_data);
344+
// When there's a tie, it should return one of them
345+
let result = knn.predict(&[1.0, 1.0]);
346+
assert!(result.is_some());
347+
let prediction = result.unwrap();
348+
assert!(prediction == "A" || prediction == "B");
348349
}
349350
}

src/machine_learning/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
mod cholesky;
22
mod k_means;
3+
mod k_nearest_neighbors;
34
mod linear_regression;
45
mod logistic_regression;
56
mod loss_function;
67
mod optimization;
7-
mod k_nearest_neighbors;
8-
98

109
pub use self::cholesky::cholesky;
1110
pub use self::k_means::k_means;
11+
pub use self::k_nearest_neighbors::{DataPoint, KNearestNeighbors};
1212
pub use self::linear_regression::linear_regression;
1313
pub use self::logistic_regression::logistic_regression;
1414
pub use self::loss_function::average_margin_ranking_loss;
@@ -19,6 +19,4 @@ pub use self::loss_function::mae_loss;
1919
pub use self::loss_function::mse_loss;
2020
pub use self::loss_function::neg_log_likelihood;
2121
pub use self::optimization::gradient_descent;
22-
pub use self::optimization::Adam;
23-
pub use self::k_nearest_neighbors::{DataPoint, KNearestNeighbors};
24-
22+
pub use self::optimization::Adam;

0 commit comments

Comments (0)