@@ -211,6 +211,12 @@ impl<F: Primitive> KMeans<F> {
211211 /// let labels = model.predict(&data).unwrap();
212212 /// assert_eq!(labels.len(), 2);
213213 /// ```
214+ ///
215+ /// # Notes
216+ ///
217+ /// This method assumes that the input `points` contains only finite floating-point values.
218+ /// Accessing `NaN` or `Infinity` in the input may result in undefined behavior or
219+ /// failure to produce a model.
214220 pub fn fit_default_scalar(points: impl AsRef<[F]>, ncols: usize, k: usize) -> Result<Self> {
215221 KMeansBuilder::new(k)
216222 .cpu_scalar()
@@ -235,6 +241,12 @@ impl<F: Primitive> KMeans<F> {
235241 /// let labels = model.predict_sequential(&data).unwrap();
236242 /// assert_eq!(labels.len(), 2);
237243 /// ```
244+ ///
245+ /// # Notes
246+ ///
247+ /// This method assumes that the input `points` contains only finite floating-point values.
248+ /// Passing `NaN` or `Infinity` in the input may result in undefined behavior or
249+ /// incorrect predictions.
238250 pub fn predict_sequential(&self, points: impl AsRef<[F]>) -> Result<Vec<usize>> {
239251 self.predict_with_backend_sequential::<CPUScalar>(points)
240252 }
@@ -261,7 +273,7 @@ impl<F: Primitive> KMeans<F> {
261273 ) -> Result<Vec<usize>> {
262274 self.validate_model_shape()?;
263275 let points = points.as_ref();
264- validate_inputs (points, self.ncols, self.k )?;
276+ validate_prediction_inputs (points, self.ncols)?;
265277
266278 let npoints = points.len() / self.ncols;
267279 let mut labels = vec![0usize; npoints];
@@ -347,7 +359,7 @@ impl<F: Primitive> KMeans<F> {
347359 source: &S,
348360 ) -> Result<Vec<usize>> {
349361 self.validate_model_shape()?;
350- validate_source_inputs (source, self.k )?;
362+ validate_prediction_source_inputs (source, self.ncols )?;
351363 let npoints = source.num_points();
352364 let mut labels = vec![0usize; npoints];
353365
@@ -393,10 +405,16 @@ impl<F: Primitive> KMeans<F> {
393405 /// For `MetricType::Euclidean` this returns squared Euclidean distances,
394406 /// for `MetricType::DotProduct` it returns raw dot-product similarities.
395407 /// Returns a flat vector of shape (n_points * k).
408+ ///
409+ /// # Notes
410+ ///
411+ /// This method assumes that the input `points` contains only finite floating-point values.
412+ /// Passing `NaN` or `Infinity` in the input may result in undefined behavior or
413+ /// incorrect results.
396414 pub fn transform(&self, points: impl AsRef<[F]>) -> Result<Vec<F>> {
397415 self.validate_model_shape()?;
398416 let points = points.as_ref();
399- validate_inputs (points, self.ncols, self.k )?;
417+ validate_prediction_inputs (points, self.ncols)?;
400418
401419 let npoints = points.len() / self.ncols;
402420 let mut distances = vec![F::zero(); npoints * self.k];
@@ -556,7 +574,7 @@ impl<F: Primitive> KMeansBuilder<F> {
556574 k,
557575 iterations: 100,
558576 attempts: 1,
559- tolerance: F::from(1e-4).unwrap_or(F::zero ()),
577+ tolerance: F::from(1e-4).unwrap_or(F::epsilon ()),
560578 seed: None,
561579 mini_batch_rel_tolerance: kmeans_mini_batch::DEFAULT_MINI_BATCH_REL_TOL,
562580 mini_batch_min_iterations: kmeans_mini_batch::DEFAULT_MINI_BATCH_MIN_ITERATIONS,
@@ -585,6 +603,12 @@ impl<F: Primitive, I: InitializationStrategy> KMeansBuilder<F, BackendNotSet, Al
585603 /// .unwrap();
586604 /// assert_eq!(model.k(), 2);
587605 /// ```
606+ ///
607+ /// # Notes
608+ ///
609+ /// When calling `fit()` on the resulting config, the input `points` must contain only
610+ /// finite floating-point values. Passing `NaN` or `Infinity` may result in undefined
611+ /// behavior or failure to produce a valid model.
588612 #[inline]
589613 pub fn build_default(self) -> KMeansConfig<F, CPUScalar, Euclidean, false, I> {
590614 self.cpu_scalar().euclidean().build()
@@ -940,7 +964,7 @@ fn validate_centroid_shape<F: Primitive>(centroids: &[F], ncols: usize, k: usize
940964 "number of centroids must be greater than zero".into(),
941965 ));
942966 }
943- if centroids.len( ) != ncols.saturating_mul(k ) {
967+ if ncols.checked_mul(k ) != Some(centroids.len() ) {
944968 return Err(KMeansError::InvalidInput(
945969 "centroids length must equal k * ncols".into(),
946970 ));
@@ -970,6 +994,13 @@ fn validate_inputs<F: Primitive>(points: &[F], ncols: usize, k: usize) -> Result
970994 "points must contain at least one row".into(),
971995 ));
972996 }
997+ let npoints = points.len() / ncols;
998+ if k > npoints {
999+ return Err(KMeansError::InvalidInput(format!(
1000+ "number of clusters k ({}) cannot be greater than number of points ({})",
1001+ k, npoints
1002+ )));
1003+ }
9731004 Ok(())
9741005}
9751006
@@ -993,6 +1024,71 @@ fn validate_source_inputs<F: Primitive, S: PointSource<F>>(source: &S, k: usize)
9931024 "number of centroids must be greater than zero".into(),
9941025 ));
9951026 }
1027+ let npoints = source.num_points();
1028+ if npoints == 0 {
1029+ return Err(KMeansError::InvalidInput(
1030+ "point source must contain at least one point".into(),
1031+ ));
1032+ }
1033+ if k > npoints {
1034+ return Err(KMeansError::InvalidInput(format!(
1035+ "number of clusters k ({}) cannot be greater than number of points ({})",
1036+ k, npoints
1037+ )));
1038+ }
1039+ if F::from(npoints).is_none() {
1040+ return Err(KMeansError::InvalidInput(format!(
1041+ "number of points ({}) cannot be represented in the chosen floating point type",
1042+ npoints
1043+ )));
1044+ }
1045+ Ok(())
1046+ }
1047+
1048+ /// Validates input dimensions for prediction/transform operations.
1049+ ///
1050+ /// Unlike training validation, this does NOT enforce `k <= n_points` because
1051+ /// a trained model can predict cluster assignments for any number of points (even 1).
1052+ #[inline]
1053+ fn validate_prediction_inputs<F: Primitive>(points: &[F], ncols: usize) -> Result<()> {
1054+ if ncols == 0 {
1055+ return Err(KMeansError::InvalidInput(
1056+ "number of columns must be greater than zero".into(),
1057+ ));
1058+ }
1059+ if !points.len().is_multiple_of(ncols) {
1060+ return Err(KMeansError::InvalidInput(
1061+ "points length must be divisible by ncols".into(),
1062+ ));
1063+ }
1064+ if points.is_empty() {
1065+ return Err(KMeansError::InvalidInput(
1066+ "points must contain at least one row".into(),
1067+ ));
1068+ }
1069+ Ok(())
1070+ }
1071+
1072+ /// Validates source input dimensions for prediction/transform operations.
1073+ ///
1074+ /// Unlike training validation, this does NOT enforce `k <= n_points` because
1075+ /// a trained model can predict cluster assignments for any number of points (even 1).
1076+ fn validate_prediction_source_inputs<F: Primitive, S: PointSource<F>>(
1077+ source: &S,
1078+ expected_ncols: usize,
1079+ ) -> Result<()> {
1080+ if source.num_columns() == 0 {
1081+ return Err(KMeansError::InvalidInput(
1082+ "number of columns must be greater than zero".into(),
1083+ ));
1084+ }
1085+ if source.num_columns() != expected_ncols {
1086+ return Err(KMeansError::DimensionMismatch(format!(
1087+ "source has {} columns but model expects {}",
1088+ source.num_columns(),
1089+ expected_ncols
1090+ )));
1091+ }
9961092 if source.num_points() == 0 {
9971093 return Err(KMeansError::InvalidInput(
9981094 "point source must contain at least one point".into(),
0 commit comments