@@ -211,6 +211,12 @@ impl<F: Primitive> KMeans<F> {
211211 /// let labels = model.predict(&data).unwrap();
212212 /// assert_eq!(labels.len(), 2);
213213 /// ```
214+ ///
215+ /// # Notes
216+ ///
217+ /// This method assumes that the input `points` contains only finite floating-point values.
218+ /// Accessing `NaN` or `Infinity` in the input may result in undefined behavior or
219+ /// failure to produce a model.
214220 pub fn fit_default_scalar ( points : impl AsRef < [ F ] > , ncols : usize , k : usize ) -> Result < Self > {
215221 KMeansBuilder :: new ( k)
216222 . cpu_scalar ( )
@@ -235,6 +241,12 @@ impl<F: Primitive> KMeans<F> {
235241 /// let labels = model.predict_sequential(&data).unwrap();
236242 /// assert_eq!(labels.len(), 2);
237243 /// ```
244+ ///
245+ /// # Notes
246+ ///
247+ /// This method assumes that the input `points` contains only finite floating-point values.
248+ /// Passing `NaN` or `Infinity` in the input may result in undefined behavior or
249+ /// incorrect predictions.
238250 pub fn predict_sequential ( & self , points : impl AsRef < [ F ] > ) -> Result < Vec < usize > > {
239251 self . predict_with_backend_sequential :: < CPUScalar > ( points)
240252 }
@@ -261,7 +273,7 @@ impl<F: Primitive> KMeans<F> {
261273 ) -> Result < Vec < usize > > {
262274 self . validate_model_shape ( ) ?;
263275 let points = points. as_ref ( ) ;
264- validate_inputs ( points, self . ncols , self . k ) ?;
276+ validate_prediction_inputs ( points, self . ncols ) ?;
265277
266278 let npoints = points. len ( ) / self . ncols ;
267279 let mut labels = vec ! [ 0usize ; npoints] ;
@@ -347,7 +359,7 @@ impl<F: Primitive> KMeans<F> {
347359 source : & S ,
348360 ) -> Result < Vec < usize > > {
349361 self . validate_model_shape ( ) ?;
350- validate_source_inputs ( source, self . k ) ?;
362+ validate_prediction_source_inputs ( source, self . ncols ) ?;
351363 let npoints = source. num_points ( ) ;
352364 let mut labels = vec ! [ 0usize ; npoints] ;
353365
@@ -393,10 +405,16 @@ impl<F: Primitive> KMeans<F> {
393405 /// For `MetricType::Euclidean` this returns squared Euclidean distances,
394406 /// for `MetricType::DotProduct` it returns raw dot-product similarities.
395407 /// Returns a flat vector of shape (n_points * k).
408+ ///
409+ /// # Notes
410+ ///
411+ /// This method assumes that the input `points` contains only finite floating-point values.
412+ /// Passing `NaN` or `Infinity` in the input may result in undefined behavior or
413+ /// incorrect results.
396414 pub fn transform ( & self , points : impl AsRef < [ F ] > ) -> Result < Vec < F > > {
397415 self . validate_model_shape ( ) ?;
398416 let points = points. as_ref ( ) ;
399- validate_inputs ( points, self . ncols , self . k ) ?;
417+ validate_prediction_inputs ( points, self . ncols ) ?;
400418
401419 let npoints = points. len ( ) / self . ncols ;
402420 let mut distances = vec ! [ F :: zero( ) ; npoints * self . k] ;
@@ -556,7 +574,7 @@ impl<F: Primitive> KMeansBuilder<F> {
556574 k,
557575 iterations : 100 ,
558576 attempts : 1 ,
559- tolerance : F :: from ( 1e-4 ) . unwrap_or ( F :: zero ( ) ) ,
577+ tolerance : F :: from ( 1e-4 ) . unwrap_or ( F :: epsilon ( ) ) ,
560578 seed : None ,
561579 mini_batch_rel_tolerance : kmeans_mini_batch:: DEFAULT_MINI_BATCH_REL_TOL ,
562580 mini_batch_min_iterations : kmeans_mini_batch:: DEFAULT_MINI_BATCH_MIN_ITERATIONS ,
@@ -585,6 +603,12 @@ impl<F: Primitive, I: InitializationStrategy> KMeansBuilder<F, BackendNotSet, Al
585603 /// .unwrap();
586604 /// assert_eq!(model.k(), 2);
587605 /// ```
606+ ///
607+ /// # Notes
608+ ///
609+ /// When calling `fit()` on the resulting config, the input `points` must contain only
610+ /// finite floating-point values. Passing `NaN` or `Infinity` may result in undefined
611+ /// behavior or failure to produce a valid model.
588612 #[ inline]
589613 pub fn build_default ( self ) -> KMeansConfig < F , CPUScalar , Euclidean , false , I > {
590614 self . cpu_scalar ( ) . euclidean ( ) . build ( )
@@ -940,7 +964,7 @@ fn validate_centroid_shape<F: Primitive>(centroids: &[F], ncols: usize, k: usize
940964 "number of centroids must be greater than zero" . into ( ) ,
941965 ) ) ;
942966 }
943- if centroids . len ( ) != ncols . saturating_mul ( k ) {
967+ if ncols . checked_mul ( k ) != Some ( centroids . len ( ) ) {
944968 return Err ( KMeansError :: InvalidInput (
945969 "centroids length must equal k * ncols" . into ( ) ,
946970 ) ) ;
@@ -970,6 +994,13 @@ fn validate_inputs<F: Primitive>(points: &[F], ncols: usize, k: usize) -> Result
970994 "points must contain at least one row" . into ( ) ,
971995 ) ) ;
972996 }
997+ let npoints = points. len ( ) / ncols;
998+ if k > npoints {
999+ return Err ( KMeansError :: InvalidInput ( format ! (
1000+ "number of clusters k ({}) cannot be greater than number of points ({})" ,
1001+ k, npoints
1002+ ) ) ) ;
1003+ }
9731004 Ok ( ( ) )
9741005}
9751006
@@ -993,6 +1024,71 @@ fn validate_source_inputs<F: Primitive, S: PointSource<F>>(source: &S, k: usize)
9931024 "number of centroids must be greater than zero" . into ( ) ,
9941025 ) ) ;
9951026 }
1027+ let npoints = source. num_points ( ) ;
1028+ if npoints == 0 {
1029+ return Err ( KMeansError :: InvalidInput (
1030+ "point source must contain at least one point" . into ( ) ,
1031+ ) ) ;
1032+ }
1033+ if k > npoints {
1034+ return Err ( KMeansError :: InvalidInput ( format ! (
1035+ "number of clusters k ({}) cannot be greater than number of points ({})" ,
1036+ k, npoints
1037+ ) ) ) ;
1038+ }
1039+ if F :: from ( npoints) . is_none ( ) {
1040+ return Err ( KMeansError :: InvalidInput ( format ! (
1041+ "number of points ({}) cannot be represented in the chosen floating point type" ,
1042+ npoints
1043+ ) ) ) ;
1044+ }
1045+ Ok ( ( ) )
1046+ }
1047+
1048+ /// Validates input dimensions for prediction/transform operations.
1049+ ///
1050+ /// Unlike training validation, this does NOT enforce `k <= n_points` because
1051+ /// a trained model can predict cluster assignments for any number of points (even 1).
1052+ #[ inline]
1053+ fn validate_prediction_inputs < F : Primitive > ( points : & [ F ] , ncols : usize ) -> Result < ( ) > {
1054+ if ncols == 0 {
1055+ return Err ( KMeansError :: InvalidInput (
1056+ "number of columns must be greater than zero" . into ( ) ,
1057+ ) ) ;
1058+ }
1059+ if !points. len ( ) . is_multiple_of ( ncols) {
1060+ return Err ( KMeansError :: InvalidInput (
1061+ "points length must be divisible by ncols" . into ( ) ,
1062+ ) ) ;
1063+ }
1064+ if points. is_empty ( ) {
1065+ return Err ( KMeansError :: InvalidInput (
1066+ "points must contain at least one row" . into ( ) ,
1067+ ) ) ;
1068+ }
1069+ Ok ( ( ) )
1070+ }
1071+
1072+ /// Validates source input dimensions for prediction/transform operations.
1073+ ///
1074+ /// Unlike training validation, this does NOT enforce `k <= n_points` because
1075+ /// a trained model can predict cluster assignments for any number of points (even 1).
1076+ fn validate_prediction_source_inputs < F : Primitive , S : PointSource < F > > (
1077+ source : & S ,
1078+ expected_ncols : usize ,
1079+ ) -> Result < ( ) > {
1080+ if source. num_columns ( ) == 0 {
1081+ return Err ( KMeansError :: InvalidInput (
1082+ "number of columns must be greater than zero" . into ( ) ,
1083+ ) ) ;
1084+ }
1085+ if source. num_columns ( ) != expected_ncols {
1086+ return Err ( KMeansError :: DimensionMismatch ( format ! (
1087+ "source has {} columns but model expects {}" ,
1088+ source. num_columns( ) ,
1089+ expected_ncols
1090+ ) ) ) ;
1091+ }
9961092 if source. num_points ( ) == 0 {
9971093 return Err ( KMeansError :: InvalidInput (
9981094 "point source must contain at least one point" . into ( ) ,
0 commit comments