//! K-Nearest Neighbors (KNN) algorithm implementation
//!
//! KNN is a supervised machine learning algorithm used for classification and regression.
//! It predicts the class/value of a data point based on the k nearest neighbors in the feature space.
//!
//! # Examples
//!
//! ```
//! use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
//!
//! let mut knn = KNearestNeighbors::new(3);
//!
//! let training_data = vec![
//!     DataPoint::new(vec![1.0, 1.0], "A".to_string()),
//!     DataPoint::new(vec![2.0, 2.0], "A".to_string()),
//!     DataPoint::new(vec![5.0, 5.0], "B".to_string()),
//! ];
//!
//! knn.fit(training_data);
//!
//! let prediction = knn.predict(&[1.5, 1.5]);
//! assert_eq!(prediction, Some("A".to_string()));
//! ```
24+
25+ use std:: collections:: HashMap ;
26+
/// A single labeled observation: a feature vector paired with its class label.
#[derive(Debug, Clone, PartialEq)]
pub struct DataPoint {
    /// Coordinates of this observation in feature space.
    pub features: Vec<f64>,
    /// Class label attached to this observation.
    pub label: String,
}
33+
34+ impl DataPoint {
35+ /// Creates a new DataPoint
36+ ///
37+ /// # Arguments
38+ ///
39+ /// * `features` - Feature vector for the data point
40+ /// * `label` - Class label for the data point
41+ ///
42+ /// # Examples
43+ ///
44+ /// ```
45+ /// use the_algorithms_rust::machine_learning::DataPoint;
46+ ///
47+ /// let point = DataPoint::new(vec![1.0, 2.0], "A".to_string());
48+ /// ```
49+ pub fn new ( features : Vec < f64 > , label : String ) -> Self {
50+ DataPoint { features, label }
51+ }
52+ }
53+
54+ /// K-Nearest Neighbors classifier
55+ ///
56+ /// # Examples
57+ ///
58+ /// ```
59+ /// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
60+ ///
61+ /// let mut knn = KNearestNeighbors::new(3);
62+ /// ```
63+ #[ derive( Debug ) ]
64+ pub struct KNearestNeighbors {
65+ k : usize ,
66+ training_data : Vec < DataPoint > ,
67+ }
68+
69+ impl KNearestNeighbors {
70+ /// Creates a new KNN classifier with k neighbors
71+ ///
72+ /// # Arguments
73+ ///
74+ /// * `k` - Number of nearest neighbors to consider
75+ ///
76+ /// # Panics
77+ ///
78+ /// Panics if k is 0
79+ ///
80+ /// # Examples
81+ ///
82+ /// ```
83+ /// use the_algorithms_rust::machine_learning::KNearestNeighbors;
84+ ///
85+ /// let knn = KNearestNeighbors::new(3);
86+ /// ```
87+ pub fn new ( k : usize ) -> Self {
88+ assert ! ( k > 0 , "k must be greater than 0" ) ;
89+ KNearestNeighbors {
90+ k,
91+ training_data : Vec :: new ( ) ,
92+ }
93+ }
94+
95+ /// Trains the KNN model with training data
96+ ///
97+ /// # Arguments
98+ ///
99+ /// * `training_data` - Vector of labeled data points
100+ ///
101+ /// # Examples
102+ ///
103+ /// ```
104+ /// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
105+ ///
106+ /// let mut knn = KNearestNeighbors::new(3);
107+ /// let data = vec![DataPoint::new(vec![1.0, 2.0], "A".to_string())];
108+ /// knn.fit(data);
109+ /// ```
110+ pub fn fit ( & mut self , training_data : Vec < DataPoint > ) {
111+ self . training_data = training_data;
112+ }
113+
114+ /// Calculates Euclidean distance between two feature vectors
115+ ///
116+ /// # Panics
117+ ///
118+ /// Panics if feature vectors have different lengths
119+ fn euclidean_distance ( & self , a : & [ f64 ] , b : & [ f64 ] ) -> f64 {
120+ assert_eq ! (
121+ a. len( ) ,
122+ b. len( ) ,
123+ "Feature vectors must have the same length"
124+ ) ;
125+ a. iter ( )
126+ . zip ( b. iter ( ) )
127+ . map ( |( x, y) | ( x - y) . powi ( 2 ) )
128+ . sum :: < f64 > ( )
129+ . sqrt ( )
130+ }
131+
132+ /// Predicts the label for a given data point
133+ ///
134+ /// Returns `None` if training data is empty
135+ ///
136+ /// # Arguments
137+ ///
138+ /// * `features` - Feature vector to classify
139+ ///
140+ /// # Examples
141+ ///
142+ /// ```
143+ /// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
144+ ///
145+ /// let mut knn = KNearestNeighbors::new(1);
146+ /// knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
147+ /// let result = knn.predict(&vec![1.5, 1.5]);
148+ /// assert_eq!(result, Some("A".to_string()));
149+ /// ```
150+ pub fn predict ( & self , features : & [ f64 ] ) -> Option < String > {
151+ if self . training_data . is_empty ( ) {
152+ return None ;
153+ }
154+
155+ // Calculate distances to all training points
156+ let mut distances: Vec < ( f64 , & DataPoint ) > = self
157+ . training_data
158+ . iter ( )
159+ . map ( |point| ( self . euclidean_distance ( features, & point. features ) , point) )
160+ . collect ( ) ;
161+
162+ // Sort by distance
163+ distances. sort_by ( |a, b| a. 0 . partial_cmp ( & b. 0 ) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
164+
165+ // Take k nearest neighbors
166+ let k_nearest = & distances[ ..self . k . min ( distances. len ( ) ) ] ;
167+
168+ // Count votes for each label
169+ let mut votes: HashMap < String , usize > = HashMap :: new ( ) ;
170+ for ( _, point) in k_nearest {
171+ * votes. entry ( point. label . clone ( ) ) . or_insert ( 0 ) += 1 ;
172+ }
173+
174+ // Return the label with the most votes
175+ votes
176+ . into_iter ( )
177+ . max_by_key ( |( _, count) | * count)
178+ . map ( |( label, _) | label)
179+ }
180+
181+ /// Predicts labels for multiple data points
182+ ///
183+ /// # Arguments
184+ ///
185+ /// * `features_batch` - Slice of feature vectors to classify
186+ ///
187+ /// # Examples
188+ ///
189+ /// ```
190+ /// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
191+ ///
192+ /// let mut knn = KNearestNeighbors::new(1);
193+ /// knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
194+ /// let results = knn.predict_batch(&[vec![1.5, 1.5], vec![1.2, 1.2]]);
195+ /// ```
196+ pub fn predict_batch ( & self , features_batch : & [ Vec < f64 > ] ) -> Vec < Option < String > > {
197+ features_batch
198+ . iter ( )
199+ . map ( |features| self . predict ( features) )
200+ . collect ( )
201+ }
202+
203+ /// Calculates accuracy on test data
204+ ///
205+ /// Returns accuracy as a value between 0.0 and 1.0
206+ ///
207+ /// # Arguments
208+ ///
209+ /// * `test_data` - Test data points with known labels
210+ ///
211+ /// # Examples
212+ ///
213+ /// ```
214+ /// use the_algorithms_rust::machine_learning::{DataPoint, KNearestNeighbors};
215+ ///
216+ /// let mut knn = KNearestNeighbors::new(1);
217+ /// knn.fit(vec![DataPoint::new(vec![1.0, 1.0], "A".to_string())]);
218+ /// let test_data = vec![DataPoint::new(vec![1.1, 1.1], "A".to_string())];
219+ /// let accuracy = knn.score(&test_data);
220+ /// assert!(accuracy > 0.0);
221+ /// ```
222+ pub fn score ( & self , test_data : & [ DataPoint ] ) -> f64 {
223+ if test_data. is_empty ( ) {
224+ return 0.0 ;
225+ }
226+
227+ let correct = test_data
228+ . iter ( )
229+ . filter ( |point| {
230+ if let Some ( predicted) = self . predict ( & point. features ) {
231+ predicted == point. label
232+ } else {
233+ false
234+ }
235+ } )
236+ . count ( ) ;
237+
238+ correct as f64 / test_data. len ( ) as f64
239+ }
240+ }
241+
#[cfg(test)]
mod tests {
    use super::*;

    // Note: `predict`/`euclidean_distance` take `&[f64]`, so slice literals
    // are passed directly instead of allocating with `&vec![..]`
    // (clippy::useless_vec).

    #[test]
    fn test_knn_simple_classification() {
        let mut knn = KNearestNeighbors::new(3);

        let training_data = vec![
            DataPoint::new(vec![1.0, 1.0], "A".to_string()),
            DataPoint::new(vec![1.5, 1.5], "A".to_string()),
            DataPoint::new(vec![2.0, 2.0], "A".to_string()),
            DataPoint::new(vec![5.0, 5.0], "B".to_string()),
            DataPoint::new(vec![5.5, 5.5], "B".to_string()),
            DataPoint::new(vec![6.0, 6.0], "B".to_string()),
        ];

        knn.fit(training_data);

        assert_eq!(knn.predict(&[1.2, 1.2]).unwrap(), "A");
        assert_eq!(knn.predict(&[5.2, 5.2]).unwrap(), "B");
    }

    #[test]
    fn test_euclidean_distance() {
        let knn = KNearestNeighbors::new(1);
        // 3-4-5 right triangle.
        let distance = knn.euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((distance - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_knn_with_k_equals_one() {
        let mut knn = KNearestNeighbors::new(1);

        let training_data = vec![
            DataPoint::new(vec![1.0, 1.0], "A".to_string()),
            DataPoint::new(vec![10.0, 10.0], "B".to_string()),
        ];

        knn.fit(training_data);

        assert_eq!(knn.predict(&[1.5, 1.5]).unwrap(), "A");
        assert_eq!(knn.predict(&[9.5, 9.5]).unwrap(), "B");
    }

    #[test]
    fn test_knn_accuracy() {
        let mut knn = KNearestNeighbors::new(3);

        let training_data = vec![
            DataPoint::new(vec![1.0, 1.0], "A".to_string()),
            DataPoint::new(vec![1.5, 1.5], "A".to_string()),
            DataPoint::new(vec![2.0, 2.0], "A".to_string()),
            DataPoint::new(vec![5.0, 5.0], "B".to_string()),
            DataPoint::new(vec![5.5, 5.5], "B".to_string()),
            DataPoint::new(vec![6.0, 6.0], "B".to_string()),
        ];

        knn.fit(training_data);

        let test_data = vec![
            DataPoint::new(vec![1.2, 1.2], "A".to_string()),
            DataPoint::new(vec![5.2, 5.2], "B".to_string()),
        ];

        let accuracy = knn.score(&test_data);
        assert!((accuracy - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_predict_batch() {
        let mut knn = KNearestNeighbors::new(3);

        let training_data = vec![
            DataPoint::new(vec![1.0, 1.0], "A".to_string()),
            DataPoint::new(vec![2.0, 2.0], "A".to_string()),
            DataPoint::new(vec![5.0, 5.0], "B".to_string()),
            DataPoint::new(vec![6.0, 6.0], "B".to_string()),
        ];

        knn.fit(training_data);

        let features_batch = vec![vec![1.5, 1.5], vec![5.5, 5.5]];
        let predictions = knn.predict_batch(&features_batch);

        assert_eq!(predictions[0].as_ref().unwrap(), "A");
        assert_eq!(predictions[1].as_ref().unwrap(), "B");
    }

    #[test]
    #[should_panic(expected = "k must be greater than 0")]
    fn test_knn_zero_k() {
        KNearestNeighbors::new(0);
    }

    #[test]
    fn test_empty_training_data() {
        let knn = KNearestNeighbors::new(3);
        assert!(knn.predict(&[1.0, 1.0]).is_none());
    }

    #[test]
    #[should_panic(expected = "Feature vectors must have the same length")]
    fn test_mismatched_feature_lengths() {
        let knn = KNearestNeighbors::new(1);
        knn.euclidean_distance(&[1.0, 2.0], &[1.0]);
    }
}