Skip to content

Commit 487f041

Browse files
committed
tests: add property based tests
1 parent eb308e5 commit 487f041

File tree

2 files changed

+307
-1
lines changed

2 files changed

+307
-1
lines changed

crates/approx-chol/tests/common/laplacian_prop.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use proptest::prelude::*;
2+
use std::collections::VecDeque;
23

34
pub type LaplacianCsr = (Vec<u32>, Vec<u32>, Vec<f64>, u32);
45

@@ -54,3 +55,95 @@ pub fn rhs_for_dimension(n: usize) -> Vec<f64> {
5455
}
5556
rhs
5657
}
58+
59+
#[allow(dead_code)]
60+
pub fn csr_matvec(row_ptrs: &[u32], col_indices: &[u32], values: &[f64], x: &[f64]) -> Vec<f64> {
61+
let n = row_ptrs.len() - 1;
62+
let mut y = vec![0.0; n];
63+
for i in 0..n {
64+
let start = row_ptrs[i] as usize;
65+
let end = row_ptrs[i + 1] as usize;
66+
for k in start..end {
67+
y[i] += values[k] * x[col_indices[k] as usize];
68+
}
69+
}
70+
y
71+
}
72+
73+
#[allow(dead_code)]
74+
pub fn is_connected(row_ptrs: &[u32], col_indices: &[u32], n: u32) -> bool {
75+
if n <= 1 {
76+
return true;
77+
}
78+
let n = n as usize;
79+
let mut visited = vec![false; n];
80+
let mut queue = VecDeque::new();
81+
visited[0] = true;
82+
queue.push_back(0usize);
83+
while let Some(v) = queue.pop_front() {
84+
let start = row_ptrs[v] as usize;
85+
let end = row_ptrs[v + 1] as usize;
86+
for &col in &col_indices[start..end] {
87+
let u = col as usize;
88+
if !visited[u] {
89+
visited[u] = true;
90+
queue.push_back(u);
91+
}
92+
}
93+
}
94+
visited.iter().all(|&v| v)
95+
}
96+
97+
#[allow(dead_code)]
98+
pub fn norm2(v: &[f64]) -> f64 {
99+
v.iter().map(|x| x * x).sum::<f64>().sqrt()
100+
}
101+
102+
#[allow(dead_code)]
103+
pub fn random_zero_sum_rhs_strategy(n: usize) -> BoxedStrategy<Vec<f64>> {
104+
if n <= 1 {
105+
Just(vec![0.0; n]).boxed()
106+
} else {
107+
prop::collection::vec(-10.0f64..10.0, n)
108+
.prop_map(|mut v| {
109+
let mean = v.iter().sum::<f64>() / v.len() as f64;
110+
for x in &mut v {
111+
*x -= mean;
112+
}
113+
v
114+
})
115+
.boxed()
116+
}
117+
}
118+
119+
#[allow(dead_code)]
120+
pub fn laplacian_with_rhs_strategy() -> impl Strategy<Value = (LaplacianCsr, Vec<f64>)> {
121+
laplacian_csr_strategy().prop_flat_map(|(rp, ci, vals, n)| {
122+
random_zero_sum_rhs_strategy(n as usize)
123+
.prop_map(move |rhs| ((rp.clone(), ci.clone(), vals.clone(), n), rhs))
124+
})
125+
}
126+
127+
#[allow(dead_code)]
128+
pub fn sddm_csr_strategy() -> impl Strategy<Value = LaplacianCsr> {
129+
(1usize..=8).prop_flat_map(|n| {
130+
let pair_count = n * (n - 1) / 2;
131+
(
132+
prop::collection::vec(0u8..=4, pair_count),
133+
prop::collection::vec(1u8..=5, n),
134+
)
135+
.prop_map(move |(edge_weights, surpluses)| {
136+
let (rp, ci, mut vals, n_u32) = build_laplacian_csr(n, &edge_weights);
137+
for i in 0..n {
138+
let start = rp[i] as usize;
139+
let end = rp[i + 1] as usize;
140+
for k in start..end {
141+
if ci[k] as usize == i {
142+
vals[k] += surpluses[i] as f64;
143+
}
144+
}
145+
}
146+
(rp, ci, vals, n_u32)
147+
})
148+
})
149+
}

crates/approx-chol/tests/property_factorization.rs

Lines changed: 214 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,18 @@ use panic_ok::OrPanic;
55
mod laplacian_prop;
66

77
use approx_chol::{factorize, factorize_with, Config, CsrRef};
8-
use laplacian_prop::{laplacian_csr_strategy, rhs_for_dimension};
8+
use laplacian_prop::{
9+
csr_matvec, is_connected, laplacian_csr_strategy, laplacian_with_rhs_strategy, norm2,
10+
rhs_for_dimension, sddm_csr_strategy,
11+
};
912
use proptest::prelude::*;
1013
use std::panic::{catch_unwind, AssertUnwindSafe};
1114

1215
proptest! {
16+
// -----------------------------------------------------------------------
17+
// Existing: panic-freedom and finiteness
18+
// -----------------------------------------------------------------------
19+
1320
#[test]
1421
fn default_factorization_solve_is_panic_free_and_finite(
1522
(row_ptrs, col_indices, values, n) in laplacian_csr_strategy()
@@ -79,4 +86,210 @@ proptest! {
7986
let work = run.or_panic("checked above");
8087
prop_assert!(work.iter().all(|x| x.is_finite()));
8188
}
89+
90+
// -----------------------------------------------------------------------
91+
// Solution quality: residual ||Ax - b|| / ||b|| is bounded
92+
// -----------------------------------------------------------------------
93+
94+
#[test]
95+
fn ac_residual_is_bounded(
96+
(row_ptrs, col_indices, values, n) in laplacian_csr_strategy()
97+
) {
98+
prop_assume!(n >= 2);
99+
prop_assume!(is_connected(&row_ptrs, &col_indices, n));
100+
101+
let csr = CsrRef::new(&row_ptrs, &col_indices, &values, n)
102+
.or_panic("valid CSR");
103+
let factor = factorize(csr).or_panic("factorization");
104+
let rhs = rhs_for_dimension(n as usize);
105+
let x = factor.solve(&rhs).or_panic("solve");
106+
107+
let ax = csr_matvec(&row_ptrs, &col_indices, &values, &x);
108+
let residual: Vec<f64> = ax.iter().zip(rhs.iter()).map(|(a, b)| a - b).collect();
109+
let r_norm = norm2(&residual);
110+
let b_norm = norm2(&rhs);
111+
112+
prop_assert!(
113+
b_norm < 1e-15 || r_norm / b_norm < 100.0,
114+
"AC relative residual too large: {:.4e} (r_norm={:.4e}, b_norm={:.4e})",
115+
r_norm / b_norm, r_norm, b_norm
116+
);
117+
}
118+
119+
#[test]
120+
fn ac2_residual_is_bounded(
121+
(row_ptrs, col_indices, values, n) in laplacian_csr_strategy()
122+
) {
123+
prop_assume!(n >= 2);
124+
prop_assume!(is_connected(&row_ptrs, &col_indices, n));
125+
126+
let csr = CsrRef::new(&row_ptrs, &col_indices, &values, n)
127+
.or_panic("valid CSR");
128+
let factor = factorize_with(
129+
csr,
130+
Config {
131+
seed: 7,
132+
split_merge: Some(2),
133+
},
134+
)
135+
.or_panic("AC2 factorization");
136+
let rhs = rhs_for_dimension(n as usize);
137+
let x = factor.solve(&rhs).or_panic("solve");
138+
139+
let ax = csr_matvec(&row_ptrs, &col_indices, &values, &x);
140+
let residual: Vec<f64> = ax.iter().zip(rhs.iter()).map(|(a, b)| a - b).collect();
141+
let r_norm = norm2(&residual);
142+
let b_norm = norm2(&rhs);
143+
144+
prop_assert!(
145+
b_norm < 1e-15 || r_norm / b_norm < 100.0,
146+
"AC2 relative residual too large: {:.4e} (r_norm={:.4e}, b_norm={:.4e})",
147+
r_norm / b_norm, r_norm, b_norm
148+
);
149+
}
150+
151+
#[test]
152+
fn random_rhs_residual_is_bounded(
153+
((row_ptrs, col_indices, values, n), rhs) in laplacian_with_rhs_strategy()
154+
) {
155+
prop_assume!(n >= 2);
156+
prop_assume!(is_connected(&row_ptrs, &col_indices, n));
157+
let b_norm = norm2(&rhs);
158+
prop_assume!(b_norm > 1e-10);
159+
160+
let csr = CsrRef::new(&row_ptrs, &col_indices, &values, n)
161+
.or_panic("valid CSR");
162+
let factor = factorize(csr).or_panic("factorization");
163+
let x = factor.solve(&rhs).or_panic("solve");
164+
165+
let ax = csr_matvec(&row_ptrs, &col_indices, &values, &x);
166+
let residual: Vec<f64> = ax.iter().zip(rhs.iter()).map(|(a, b)| a - b).collect();
167+
let r_norm = norm2(&residual);
168+
169+
prop_assert!(
170+
r_norm / b_norm < 100.0,
171+
"random-RHS relative residual too large: {:.4e}", r_norm / b_norm
172+
);
173+
}
174+
175+
// -----------------------------------------------------------------------
176+
// SDDM matrices (Gremban augmentation path)
177+
// -----------------------------------------------------------------------
178+
179+
#[test]
180+
fn sddm_factorization_is_panic_free_and_finite(
181+
(row_ptrs, col_indices, values, n) in sddm_csr_strategy()
182+
) {
183+
let run = catch_unwind(AssertUnwindSafe(|| {
184+
let csr = CsrRef::new(&row_ptrs, &col_indices, &values, n)
185+
.or_panic("valid SDDM CSR");
186+
let factor = factorize(csr).or_panic("factorization");
187+
let mut rhs = vec![0.0_f64; n as usize];
188+
if n >= 2 {
189+
rhs[0] = 1.0;
190+
rhs[(n as usize) - 1] = -1.0;
191+
}
192+
factor.solve(&rhs).or_panic("solve")
193+
}));
194+
195+
prop_assert!(run.is_ok(), "SDDM factorization or solve panicked");
196+
let x = run.or_panic("checked above");
197+
prop_assert!(x.iter().all(|v| v.is_finite()), "SDDM solution has non-finite values");
198+
}
199+
200+
// -----------------------------------------------------------------------
201+
// Determinism: same seed + same input → identical output
202+
// -----------------------------------------------------------------------
203+
204+
#[test]
205+
fn deterministic_with_fixed_seed(
206+
(row_ptrs, col_indices, values, n) in laplacian_csr_strategy()
207+
) {
208+
let config = Config { seed: 42, ..Default::default() };
209+
let rhs = rhs_for_dimension(n as usize);
210+
211+
let csr1 = CsrRef::new(&row_ptrs, &col_indices, &values, n)
212+
.or_panic("valid CSR");
213+
let x1 = factorize_with(csr1, config).or_panic("factorize 1")
214+
.solve(&rhs).or_panic("solve 1");
215+
216+
let csr2 = CsrRef::new(&row_ptrs, &col_indices, &values, n)
217+
.or_panic("valid CSR");
218+
let x2 = factorize_with(csr2, config).or_panic("factorize 2")
219+
.solve(&rhs).or_panic("solve 2");
220+
221+
prop_assert_eq!(x1.len(), x2.len());
222+
for (a, b) in x1.iter().zip(x2.iter()) {
223+
prop_assert!(
224+
a.to_bits() == b.to_bits(),
225+
"non-deterministic: {} vs {}", a, b
226+
);
227+
}
228+
}
229+
230+
// -----------------------------------------------------------------------
231+
// Factor dimensions are consistent
232+
// -----------------------------------------------------------------------
233+
234+
#[test]
235+
fn factor_dimensions_are_consistent(
236+
(row_ptrs, col_indices, values, n) in laplacian_csr_strategy()
237+
) {
238+
let csr = CsrRef::new(&row_ptrs, &col_indices, &values, n)
239+
.or_panic("valid CSR");
240+
let factor = factorize(csr).or_panic("factorization");
241+
242+
prop_assert_eq!(
243+
factor.original_n(), n as usize,
244+
"original_n must match input dimension"
245+
);
246+
prop_assert!(
247+
factor.n() >= n as usize,
248+
"factor.n() must be >= input dimension"
249+
);
250+
// Pure Laplacians should not be augmented
251+
prop_assert_eq!(
252+
factor.n(), n as usize,
253+
"pure Laplacian should not trigger Gremban augmentation"
254+
);
255+
}
256+
257+
#[test]
258+
fn sddm_factor_dimensions_are_consistent(
259+
(row_ptrs, col_indices, values, n) in sddm_csr_strategy()
260+
) {
261+
let csr = CsrRef::new(&row_ptrs, &col_indices, &values, n)
262+
.or_panic("valid SDDM CSR");
263+
let factor = factorize(csr).or_panic("factorization");
264+
265+
prop_assert_eq!(
266+
factor.original_n(), n as usize,
267+
"original_n must match input dimension"
268+
);
269+
prop_assert!(
270+
factor.n() > n as usize,
271+
"SDDM should trigger Gremban augmentation (factor.n() must be > n)"
272+
);
273+
}
274+
275+
// -----------------------------------------------------------------------
276+
// f32 support
277+
// -----------------------------------------------------------------------
278+
279+
#[test]
280+
fn f32_factorization_is_finite(
281+
(row_ptrs, col_indices, values_f64, n) in laplacian_csr_strategy()
282+
) {
283+
let values_f32: Vec<f32> = values_f64.iter().map(|&v| v as f32).collect();
284+
let csr = CsrRef::new(&row_ptrs, &col_indices, &values_f32, n)
285+
.or_panic("valid f32 CSR");
286+
let factor = factorize(csr).or_panic("f32 factorization");
287+
let mut rhs = vec![0.0_f32; n as usize];
288+
if n >= 2 {
289+
rhs[0] = 1.0;
290+
rhs[(n as usize) - 1] = -1.0;
291+
}
292+
let x = factor.solve(&rhs).or_panic("f32 solve");
293+
prop_assert!(x.iter().all(|v| v.is_finite()), "f32 solution has non-finite values");
294+
}
82295
}

0 commit comments

Comments
 (0)