Skip to content

Commit e792add

Browse files
committed
Refactor validation logic and update CI
1 parent a18aa1e commit e792add

File tree

10 files changed

+176
-16
lines changed

10 files changed

+176
-16
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ jobs:
3636
run: cargo check --all-targets
3737
- name: cargo check (no default features)
3838
run: cargo check --all-targets --no-default-features
39+
- name: cargo check (serde)
40+
run: cargo check --all-targets --features serde
3941

4042
test:
4143
name: Cargo Test

src/kmeans_core_scalar.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ impl<F: Primitive> CoreBackend<F> for ScalarBackend {
3737
for c in 0..counts.len() {
3838
if counts[c] > 0 {
3939
let base = c * ncols;
40-
let inv_count = F::one() / F::from(counts[c]).unwrap_or(F::one());
40+
let count_f = F::from_usize(counts[c]);
41+
let inv_count = F::one() / count_f;
4142
for j in 0..ncols {
4243
centroids[base + j] = sums[base + j] * inv_count;
4344
}

src/kmeans_core_simd.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::kmeans_core_common::{
44
find_nearest_centroids_generic,
55
};
66
use crate::point_source::PointSource;
7+
use crate::primitive::Primitive;
78

89
pub struct SimdBackend;
910

@@ -476,7 +477,7 @@ macro_rules! impl_simd_backend {
476477
}
477478

478479
if counts[k_idx] > 0 {
479-
inv_counts[lane] = 1.0 / counts[k_idx] as $scalar;
480+
inv_counts[lane] = 1.0 / <$scalar>::from_usize(counts[k_idx]);
480481
active_mask[lane] = true;
481482
} else if source.num_points() > 0 {
482483
zero_indices[lane] = rng.random_range(0..source.num_points());

src/kmeans_cpu.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ pub(crate) fn run<
2424
tolerance: F,
2525
seed: Option<u64>,
2626
) -> Result<(Vec<F>, F)> {
27-
let _npoints = source.num_points();
2827
let ncols = source.num_columns();
2928
let mut rng = match seed {
3029
Some(s) => StdRng::seed_from_u64(s),
@@ -46,7 +45,7 @@ pub(crate) fn run<
4645
par_chunk_size,
4746
)?;
4847

49-
inertia = F::from(iter_inertia).unwrap_or(F::zero());
48+
inertia = F::from(iter_inertia).ok_or(crate::error::Error::ConversionFailure)?;
5049

5150
let old_centroids = C::finalize_centroids(&prepared_centroids, ncols, k);
5251

src/kmeans_mini_batch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ pub(crate) fn run<
139139
let (_, _, full_inertia) =
140140
E::compute_stats_full::<F, C, M, S>(source, ncols, k, &prepared_centroids, par_chunk)?;
141141
let centroids = C::finalize_centroids(&prepared_centroids, ncols, k);
142-
let inertia = F::from(full_inertia).unwrap_or(F::zero());
142+
let inertia = F::from(full_inertia).ok_or(KMeansError::ConversionFailure)?;
143143

144144
Ok((centroids, inertia))
145145
}

src/lib.rs

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@ impl<F: Primitive> KMeans<F> {
211211
/// let labels = model.predict(&data).unwrap();
212212
/// assert_eq!(labels.len(), 2);
213213
/// ```
214+
///
215+
/// # Notes
216+
///
217+
/// This method assumes that the input `points` contains only finite floating-point values.
218+
/// Passing `NaN` or `Infinity` in the input may result in undefined behavior or
219+
/// failure to produce a model.
214220
pub fn fit_default_scalar(points: impl AsRef<[F]>, ncols: usize, k: usize) -> Result<Self> {
215221
KMeansBuilder::new(k)
216222
.cpu_scalar()
@@ -235,6 +241,12 @@ impl<F: Primitive> KMeans<F> {
235241
/// let labels = model.predict_sequential(&data).unwrap();
236242
/// assert_eq!(labels.len(), 2);
237243
/// ```
244+
///
245+
/// # Notes
246+
///
247+
/// This method assumes that the input `points` contains only finite floating-point values.
248+
/// Passing `NaN` or `Infinity` in the input may result in undefined behavior or
249+
/// incorrect predictions.
238250
pub fn predict_sequential(&self, points: impl AsRef<[F]>) -> Result<Vec<usize>> {
239251
self.predict_with_backend_sequential::<CPUScalar>(points)
240252
}
@@ -261,7 +273,7 @@ impl<F: Primitive> KMeans<F> {
261273
) -> Result<Vec<usize>> {
262274
self.validate_model_shape()?;
263275
let points = points.as_ref();
264-
validate_inputs(points, self.ncols, self.k)?;
276+
validate_prediction_inputs(points, self.ncols)?;
265277

266278
let npoints = points.len() / self.ncols;
267279
let mut labels = vec![0usize; npoints];
@@ -347,7 +359,7 @@ impl<F: Primitive> KMeans<F> {
347359
source: &S,
348360
) -> Result<Vec<usize>> {
349361
self.validate_model_shape()?;
350-
validate_source_inputs(source, self.k)?;
362+
validate_prediction_source_inputs(source, self.ncols)?;
351363
let npoints = source.num_points();
352364
let mut labels = vec![0usize; npoints];
353365

@@ -393,10 +405,16 @@ impl<F: Primitive> KMeans<F> {
393405
/// For `MetricType::Euclidean` this returns squared Euclidean distances,
394406
/// for `MetricType::DotProduct` it returns raw dot-product similarities.
395407
/// Returns a flat vector of shape (n_points * k).
408+
///
409+
/// # Notes
410+
///
411+
/// This method assumes that the input `points` contains only finite floating-point values.
412+
/// Passing `NaN` or `Infinity` in the input may result in undefined behavior or
413+
/// incorrect results.
396414
pub fn transform(&self, points: impl AsRef<[F]>) -> Result<Vec<F>> {
397415
self.validate_model_shape()?;
398416
let points = points.as_ref();
399-
validate_inputs(points, self.ncols, self.k)?;
417+
validate_prediction_inputs(points, self.ncols)?;
400418

401419
let npoints = points.len() / self.ncols;
402420
let mut distances = vec![F::zero(); npoints * self.k];
@@ -556,7 +574,7 @@ impl<F: Primitive> KMeansBuilder<F> {
556574
k,
557575
iterations: 100,
558576
attempts: 1,
559-
tolerance: F::from(1e-4).unwrap_or(F::zero()),
577+
tolerance: F::from(1e-4).unwrap_or(F::epsilon()),
560578
seed: None,
561579
mini_batch_rel_tolerance: kmeans_mini_batch::DEFAULT_MINI_BATCH_REL_TOL,
562580
mini_batch_min_iterations: kmeans_mini_batch::DEFAULT_MINI_BATCH_MIN_ITERATIONS,
@@ -585,6 +603,12 @@ impl<F: Primitive, I: InitializationStrategy> KMeansBuilder<F, BackendNotSet, Al
585603
/// .unwrap();
586604
/// assert_eq!(model.k(), 2);
587605
/// ```
606+
///
607+
/// # Notes
608+
///
609+
/// When calling `fit()` on the resulting config, the input `points` must contain only
610+
/// finite floating-point values. Passing `NaN` or `Infinity` may result in undefined
611+
/// behavior or failure to produce a valid model.
588612
#[inline]
589613
pub fn build_default(self) -> KMeansConfig<F, CPUScalar, Euclidean, false, I> {
590614
self.cpu_scalar().euclidean().build()
@@ -940,7 +964,7 @@ fn validate_centroid_shape<F: Primitive>(centroids: &[F], ncols: usize, k: usize
940964
"number of centroids must be greater than zero".into(),
941965
));
942966
}
943-
if centroids.len() != ncols.saturating_mul(k) {
967+
if ncols.checked_mul(k) != Some(centroids.len()) {
944968
return Err(KMeansError::InvalidInput(
945969
"centroids length must equal k * ncols".into(),
946970
));
@@ -970,6 +994,13 @@ fn validate_inputs<F: Primitive>(points: &[F], ncols: usize, k: usize) -> Result
970994
"points must contain at least one row".into(),
971995
));
972996
}
997+
let npoints = points.len() / ncols;
998+
if k > npoints {
999+
return Err(KMeansError::InvalidInput(format!(
1000+
"number of clusters k ({}) cannot be greater than number of points ({})",
1001+
k, npoints
1002+
)));
1003+
}
9731004
Ok(())
9741005
}
9751006

@@ -993,6 +1024,71 @@ fn validate_source_inputs<F: Primitive, S: PointSource<F>>(source: &S, k: usize)
9931024
"number of centroids must be greater than zero".into(),
9941025
));
9951026
}
1027+
let npoints = source.num_points();
1028+
if npoints == 0 {
1029+
return Err(KMeansError::InvalidInput(
1030+
"point source must contain at least one point".into(),
1031+
));
1032+
}
1033+
if k > npoints {
1034+
return Err(KMeansError::InvalidInput(format!(
1035+
"number of clusters k ({}) cannot be greater than number of points ({})",
1036+
k, npoints
1037+
)));
1038+
}
1039+
if F::from(npoints).is_none() {
1040+
return Err(KMeansError::InvalidInput(format!(
1041+
"number of points ({}) cannot be represented in the chosen floating point type",
1042+
npoints
1043+
)));
1044+
}
1045+
Ok(())
1046+
}
1047+
1048+
/// Validates input dimensions for prediction/transform operations.
1049+
///
1050+
/// Unlike training validation, this does NOT enforce `k <= n_points` because
1051+
/// a trained model can predict cluster assignments for any number of points (even 1).
1052+
#[inline]
1053+
fn validate_prediction_inputs<F: Primitive>(points: &[F], ncols: usize) -> Result<()> {
1054+
if ncols == 0 {
1055+
return Err(KMeansError::InvalidInput(
1056+
"number of columns must be greater than zero".into(),
1057+
));
1058+
}
1059+
if !points.len().is_multiple_of(ncols) {
1060+
return Err(KMeansError::InvalidInput(
1061+
"points length must be divisible by ncols".into(),
1062+
));
1063+
}
1064+
if points.is_empty() {
1065+
return Err(KMeansError::InvalidInput(
1066+
"points must contain at least one row".into(),
1067+
));
1068+
}
1069+
Ok(())
1070+
}
1071+
1072+
/// Validates source input dimensions for prediction/transform operations.
1073+
///
1074+
/// Unlike training validation, this does NOT enforce `k <= n_points` because
1075+
/// a trained model can predict cluster assignments for any number of points (even 1).
1076+
fn validate_prediction_source_inputs<F: Primitive, S: PointSource<F>>(
1077+
source: &S,
1078+
expected_ncols: usize,
1079+
) -> Result<()> {
1080+
if source.num_columns() == 0 {
1081+
return Err(KMeansError::InvalidInput(
1082+
"number of columns must be greater than zero".into(),
1083+
));
1084+
}
1085+
if source.num_columns() != expected_ncols {
1086+
return Err(KMeansError::DimensionMismatch(format!(
1087+
"source has {} columns but model expects {}",
1088+
source.num_columns(),
1089+
expected_ncols
1090+
)));
1091+
}
9961092
if source.num_points() == 0 {
9971093
return Err(KMeansError::InvalidInput(
9981094
"point source must contain at least one point".into(),

src/primitive.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
use num_traits::Float;
22
use std::fmt::Debug;
33

4-
pub trait Primitive: Float + Debug + Send + Sync + 'static {}
5-
impl<T: Float + Debug + Send + Sync + 'static> Primitive for T {}
4+
pub trait Primitive: Float + Debug + Send + Sync + 'static {
5+
/// Converts a `usize` to this primitive type.
6+
///
7+
/// This method is preferred over `as` casting or generic `From` traits to ensure
8+
/// consistent behavior across backends. The provided impls are a plain `as` cast, so very large values may round
9+
/// (though for k-means counts/indices, values are expected to fit).
10+
fn from_usize(n: usize) -> Self;
11+
}
12+
13+
impl Primitive for f32 {
14+
#[inline(always)]
15+
fn from_usize(n: usize) -> Self {
16+
n as f32
17+
}
18+
}
19+
20+
impl Primitive for f64 {
21+
#[inline(always)]
22+
fn from_usize(n: usize) -> Self {
23+
n as f64
24+
}
25+
}

src/wasm.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use js_sys::Uint32Array;
66
use wasm_bindgen::prelude::*;
77

88
macro_rules! wasm_model {
9-
($name:ident, $float:ty, $to_array:ident, $doc:literal) => {
9+
($name:ident, $float:ty, $doc:literal) => {
1010
#[wasm_bindgen]
1111
#[doc = $doc]
1212
pub struct $name {
@@ -106,14 +106,12 @@ macro_rules! wasm_model {
106106
wasm_model!(
107107
WasmModel,
108108
f32,
109-
to_array_f32,
110109
"Wasm-friendly K-Means++ model using f32 inputs/outputs."
111110
);
112111

113112
wasm_model!(
114113
WasmModelF64,
115114
f64,
116-
to_array_f64,
117115
"Wasm-friendly K-Means++ model using f64 inputs/outputs."
118116
);
119117

tests/api_compile_tests.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ fn test_builder_type_state_progression() {
132132

133133
#[test]
134134
fn test_mini_batch_api_compiles() {
135-
let points = [0.0_f32, 1.0_f32];
135+
// Need at least k points for k clusters
136+
let points = [0.0_f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
136137
let source = SlicePointSource::<f32>::new(&points, 1).unwrap();
137138
let _ = KMeansBuilder::<f32>::new(8)
138139
.cpu_scalar()

tests/validation_tests.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use kmeans_uni::{Error, KMeansBuilder};
2+
3+
#[test]
4+
fn test_k_greater_than_n_points_fails() {
5+
let data = vec![1.0f32; 10]; // 10 points
6+
let ncols = 1;
7+
let k = 11; // k > n
8+
9+
let result = KMeansBuilder::new(k).build_default().fit(&data, ncols);
10+
11+
assert!(result.is_err());
12+
match result {
13+
Err(Error::InvalidInput(msg)) => {
14+
assert!(msg.contains("cannot be greater than number of points"));
15+
}
16+
_ => panic!("Expected InvalidInput error"),
17+
}
18+
}
19+
20+
#[test]
21+
fn test_k_equals_n_points_succeeds() {
22+
let data = vec![1.0f32, 2.0, 3.0]; // 3 points
23+
let ncols = 1;
24+
let k = 3; // k == n
25+
26+
let result = KMeansBuilder::new(k).build_default().fit(&data, ncols);
27+
28+
if let Err(e) = result {
29+
panic!("Validation failed for k=n: {:?}", e);
30+
}
31+
}
32+
33+
#[test]
34+
fn test_checked_mul_overflow_theoretical() {
35+
let data = vec![0.0f32; 100];
36+
let ncols = 2;
37+
let k = 5;
38+
39+
let result = KMeansBuilder::new(k).build_default().fit(&data, ncols);
40+
41+
assert!(result.is_ok());
42+
}

0 commit comments

Comments
 (0)