11//! K-Nearest Neighbors (KNN) algorithm implementation
2- //!
2+ //!
33//! KNN is a supervised machine learning algorithm used for classification and regression.
44//! It predicts the class/value of a data point based on the k nearest neighbors in the feature space.
55//!
1717//! ];
1818//!
1919//! knn.fit(training_data);
20- //!
21- //! let prediction = knn.predict(&vec! [1.5, 1.5]);
20+ //!
21+ //! let prediction = knn.predict(&[1.5, 1.5]);
2222//! assert_eq!(prediction, Some("A".to_string()));
2323//! ```
24-
2524use std:: collections:: HashMap ;
26-
/// Represents a data point with features and a label
///
/// One labelled sample: a position in feature space together with the class
/// it belongs to.
#[derive(Debug, Clone, PartialEq)]
pub struct DataPoint {
    /// Coordinates of the sample in feature space.
    pub features: Vec<f64>,
    /// Class label attached to the sample.
    pub label: String,
}
33-
3431impl DataPoint {
3532 /// Creates a new DataPoint
3633 ///
@@ -50,22 +47,20 @@ impl DataPoint {
5047 DataPoint { features, label }
5148 }
5249}
53-
5450/// K-Nearest Neighbors classifier
5551///
5652/// # Examples
5753///
5854/// ```
59- /// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors} ;
55+ /// use the_algorithms_rust::machine_learning::KNearestNeighbors;
6056///
61- /// let mut knn = KNearestNeighbors::new(3);
57+ /// let knn = KNearestNeighbors::new(3);
6258/// ```
6359#[ derive( Debug ) ]
6460pub struct KNearestNeighbors {
6561 k : usize ,
6662 training_data : Vec < DataPoint > ,
6763}
68-
6964impl KNearestNeighbors {
7065 /// Creates a new KNN classifier with k neighbors
7166 ///
@@ -91,7 +86,6 @@ impl KNearestNeighbors {
9186 training_data : Vec :: new ( ) ,
9287 }
9388 }
94-
9589 /// Trains the KNN model with training data
9690 ///
9791 /// # Arguments
@@ -110,7 +104,6 @@ impl KNearestNeighbors {
110104 pub fn fit ( & mut self , training_data : Vec < DataPoint > ) {
111105 self . training_data = training_data;
112106 }
113-
114107 /// Calculates Euclidean distance between two feature vectors
115108 ///
116109 /// # Panics
@@ -128,7 +121,6 @@ impl KNearestNeighbors {
128121 . sum :: < f64 > ( )
129122 . sqrt ( )
130123 }
131-
132124 /// Predicts the label for a given data point
133125 ///
134126 /// Returns `None` if training data is empty
@@ -144,40 +136,34 @@ impl KNearestNeighbors {
144136 ///
145137 /// let mut knn = KNearestNeighbors::new(1);
146138 /// knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
147- /// let result = knn.predict(&vec! [1.5, 1.5]);
139+ /// let result = knn.predict(&[1.5, 1.5]);
148140 /// assert_eq!(result, Some("A".to_string()));
149141 /// ```
150142 pub fn predict ( & self , features : & [ f64 ] ) -> Option < String > {
151143 if self . training_data . is_empty ( ) {
152144 return None ;
153145 }
154-
155146 // Calculate distances to all training points
156147 let mut distances: Vec < ( f64 , & DataPoint ) > = self
157148 . training_data
158149 . iter ( )
159150 . map ( |point| ( self . euclidean_distance ( features, & point. features ) , point) )
160151 . collect ( ) ;
161-
162152 // Sort by distance
163153 distances. sort_by ( |a, b| a. 0 . partial_cmp ( & b. 0 ) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
164-
165154 // Take k nearest neighbors
166155 let k_nearest = & distances[ ..self . k . min ( distances. len ( ) ) ] ;
167-
168156 // Count votes for each label
169157 let mut votes: HashMap < String , usize > = HashMap :: new ( ) ;
170158 for ( _, point) in k_nearest {
171159 * votes. entry ( point. label . clone ( ) ) . or_insert ( 0 ) += 1 ;
172160 }
173-
174161 // Return the label with the most votes
175162 votes
176163 . into_iter ( )
177164 . max_by_key ( |( _, count) | * count)
178165 . map ( |( label, _) | label)
179166 }
180-
181167 /// Predicts labels for multiple data points
182168 ///
183169 /// # Arguments
@@ -199,7 +185,6 @@ impl KNearestNeighbors {
199185 . map ( |features| self . predict ( features) )
200186 . collect ( )
201187 }
202-
203188 /// Calculates accuracy on test data
204189 ///
205190 /// Returns accuracy as a value between 0.0 and 1.0
@@ -223,7 +208,6 @@ impl KNearestNeighbors {
223208 if test_data. is_empty ( ) {
224209 return 0.0 ;
225210 }
226-
227211 let correct = test_data
228212 . iter ( )
229213 . filter ( |point| {
@@ -234,19 +218,15 @@ impl KNearestNeighbors {
234218 }
235219 } )
236220 . count ( ) ;
237-
238221 correct as f64 / test_data. len ( ) as f64
239222 }
240223}
241-
242224#[ cfg( test) ]
243225mod tests {
244226 use super :: * ;
245-
246227 #[ test]
247228 fn test_knn_simple_classification ( ) {
248229 let mut knn = KNearestNeighbors :: new ( 3 ) ;
249-
250230 let training_data = vec ! [
251231 DataPoint :: new( vec![ 1.0 , 1.0 ] , "A" . to_string( ) ) ,
252232 DataPoint :: new( vec![ 1.5 , 1.5 ] , "A" . to_string( ) ) ,
@@ -255,39 +235,30 @@ mod tests {
255235 DataPoint :: new( vec![ 5.5 , 5.5 ] , "B" . to_string( ) ) ,
256236 DataPoint :: new( vec![ 6.0 , 6.0 ] , "B" . to_string( ) ) ,
257237 ] ;
258-
259238 knn. fit ( training_data) ;
260-
261- assert_eq ! ( knn. predict( & vec![ 1.2 , 1.2 ] ) . unwrap( ) , "A" ) ;
262- assert_eq ! ( knn. predict( & vec![ 5.2 , 5.2 ] ) . unwrap( ) , "B" ) ;
239+ assert_eq ! ( knn. predict( & [ 1.2 , 1.2 ] ) . unwrap( ) , "A" ) ;
240+ assert_eq ! ( knn. predict( & [ 5.2 , 5.2 ] ) . unwrap( ) , "B" ) ;
263241 }
264-
#[test]
fn test_euclidean_distance() {
    // 3-4-5 right triangle: the distance from the origin to (3, 4) is 5.
    let knn = KNearestNeighbors::new(1);
    let distance = knn.euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]);
    assert!((distance - 5.0).abs() < f64::EPSILON);
}
271-
#[test]
fn test_knn_with_k_equals_one() {
    // With k = 1 the prediction is the label of the single closest sample.
    let mut knn = KNearestNeighbors::new(1);
    knn.fit(vec![
        DataPoint::new(vec![1.0, 1.0], "A".to_string()),
        DataPoint::new(vec![10.0, 10.0], "B".to_string()),
    ]);

    assert_eq!(knn.predict(&[1.5, 1.5]).unwrap(), "A");
    assert_eq!(knn.predict(&[9.5, 9.5]).unwrap(), "B");
}
286-
287259 #[ test]
288260 fn test_knn_accuracy ( ) {
289261 let mut knn = KNearestNeighbors :: new ( 3 ) ;
290-
291262 let training_data = vec ! [
292263 DataPoint :: new( vec![ 1.0 , 1.0 ] , "A" . to_string( ) ) ,
293264 DataPoint :: new( vec![ 1.5 , 1.5 ] , "A" . to_string( ) ) ,
@@ -296,54 +267,84 @@ mod tests {
296267 DataPoint :: new( vec![ 5.5 , 5.5 ] , "B" . to_string( ) ) ,
297268 DataPoint :: new( vec![ 6.0 , 6.0 ] , "B" . to_string( ) ) ,
298269 ] ;
299-
300270 knn. fit ( training_data) ;
301-
302271 let test_data = vec ! [
303272 DataPoint :: new( vec![ 1.2 , 1.2 ] , "A" . to_string( ) ) ,
304273 DataPoint :: new( vec![ 5.2 , 5.2 ] , "B" . to_string( ) ) ,
305274 ] ;
306-
307275 let accuracy = knn. score ( & test_data) ;
308276 assert ! ( ( accuracy - 1.0 ) . abs( ) < f64 :: EPSILON ) ;
309277 }
310-
#[test]
fn test_predict_batch() {
    // Two well-separated clusters; each query sits inside one of them.
    let mut knn = KNearestNeighbors::new(3);
    knn.fit(vec![
        DataPoint::new(vec![1.0, 1.0], "A".to_string()),
        DataPoint::new(vec![2.0, 2.0], "A".to_string()),
        DataPoint::new(vec![5.0, 5.0], "B".to_string()),
        DataPoint::new(vec![6.0, 6.0], "B".to_string()),
    ]);

    let features_batch = vec![vec![1.5, 1.5], vec![5.5, 5.5]];
    let predictions = knn.predict_batch(&features_batch);

    // Results come back in the same order as the queries.
    let first = predictions[0].as_ref().unwrap();
    let second = predictions[1].as_ref().unwrap();
    assert_eq!(first, "A");
    assert_eq!(second, "B");
}
330-
#[test]
#[should_panic(expected = "k must be greater than 0")]
fn test_knn_zero_k() {
    // The constructor is expected to reject k = 0 by panicking with this message.
    let _ = KNearestNeighbors::new(0);
}
336-
#[test]
fn test_empty_training_data() {
    // A classifier that was never fitted has no samples and yields no label.
    let knn = KNearestNeighbors::new(3);
    assert!(knn.predict(&[1.0, 1.0]).is_none());
}
342-
#[test]
#[should_panic(expected = "Feature vectors must have the same length")]
fn test_mismatched_feature_lengths() {
    let knn = KNearestNeighbors::new(1);
    // Two coordinates against one: the call is expected to panic.
    knn.euclidean_distance(&[1.0, 2.0], &[1.0]);
}
#[test]
fn test_predict_batch_with_empty_training() {
    // An unfitted classifier yields `None` for every query in the batch.
    let knn = KNearestNeighbors::new(3);
    let features_batch = vec![vec![1.5, 1.5], vec![5.5, 5.5]];
    let predictions = knn.predict_batch(&features_batch);
    for idx in 0..2 {
        assert!(predictions[idx].is_none());
    }
}
#[test]
fn test_score_with_empty_test_data() {
    // Scoring against no samples is defined as 0.0 rather than a division by zero.
    let mut knn = KNearestNeighbors::new(3);
    let single_sample = vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())];
    knn.fit(single_sample);
    assert_eq!(knn.score(&[]), 0.0);
}
#[test]
fn test_k_larger_than_training_data() {
    // k = 10 exceeds the three available samples; prediction should still
    // work, with all three voting (two "A" against one "B").
    let mut knn = KNearestNeighbors::new(10);
    knn.fit(vec![
        DataPoint::new(vec![1.0, 1.0], "A".to_string()),
        DataPoint::new(vec![2.0, 2.0], "A".to_string()),
        DataPoint::new(vec![5.0, 5.0], "B".to_string()),
    ]);

    let label = knn.predict(&[1.5, 1.5]).unwrap();
    assert_eq!(label, "A");
}
#[test]
fn test_tie_breaking() {
    // Two samples at the same position with different labels: one vote each.
    let mut knn = KNearestNeighbors::new(2);
    knn.fit(vec![
        DataPoint::new(vec![1.0, 1.0], "A".to_string()),
        DataPoint::new(vec![1.0, 1.0], "B".to_string()),
    ]);

    // Either label is acceptable; the test only requires that one is returned.
    match knn.predict(&[1.0, 1.0]) {
        Some(prediction) => assert!(prediction == "A" || prediction == "B"),
        None => panic!("expected a prediction for a fitted classifier"),
    }
}
349350}
0 commit comments