diff --git a/kcl-ezpz/src/solver.rs b/kcl-ezpz/src/solver.rs index 90dc6247..099850dc 100644 --- a/kcl-ezpz/src/solver.rs +++ b/kcl-ezpz/src/solver.rs @@ -1,6 +1,9 @@ use std::sync::Mutex; -use faer::sparse::{Pair, SparseColMatRef, SymbolicSparseColMat, linalg::solvers::SymbolicLu}; +use faer::sparse::{ + Pair, SparseColMatRef, SymbolicSparseColMat, + linalg::{matmul::SparseMatMulInfo, solvers::SymbolicLu}, +}; use crate::{ Constraint, ConstraintEntry, NonLinearSystemError, Warning, WarningContent, @@ -96,6 +99,12 @@ pub(crate) struct Model<'c> { pub(crate) warnings: Mutex>, lambda_i: faer::sparse::SparseColMat, lu_symbolic: SymbolicLu, + jt_sym: SymbolicSparseColMat, + jtj_symbolic: (SymbolicSparseColMat, SparseMatMulInfo), + jt_value_indices: Vec, + a_sym: SymbolicSparseColMat, + a_from_jtj_indices: Vec, + lambda_diag_indices: Vec, } fn validate_variables( @@ -210,6 +219,12 @@ impl<'c> Model<'c> { // Precompute the symbolic LU of A = JᵀJ + λI so we can reuse it inside the Newton loop. let lu_symbolic = Self::precompute_symbolic_lu(&jc.sym, &lambda_i)?; + let jt_sym = jc.sym.transpose().to_col_major()?; + let jtj_symbolic = Self::precompute_symbolic_jtj(&jt_sym, &jc.sym)?; + let jt_value_indices = Self::precompute_jt_value_indices(&jc.sym, &jt_sym); + let a_sym = Self::precompute_symbolic_a(&jtj_symbolic.0, &lambda_i)?; + let a_from_jtj_indices = Self::precompute_a_from_jtj_indices(&jtj_symbolic.0, &a_sym); + let lambda_diag_indices = Self::precompute_lambda_diag_indices(&a_sym); // All done. Ok(Self { @@ -221,9 +236,93 @@ impl<'c> Model<'c> { row1_scratch: Vec::with_capacity(NONZEROES_PER_ROW), lambda_i, lu_symbolic, + jt_sym, + jtj_symbolic, + jt_value_indices, + a_sym, + a_from_jtj_indices, + lambda_diag_indices, }) } + fn precompute_symbolic_jtj( + jt_sym: &SymbolicSparseColMat, + jc_sym: &SymbolicSparseColMat, + ) -> Result<(SymbolicSparseColMat, SparseMatMulInfo), NonLinearSystemError> { + let jtj_sym = faer::sparse::linalg::matmul::sparse_sparse_matmul_symbolic( + jt_sym.as_ref(), + jc_sym.as_ref(), + )?; + Ok(jtj_sym) + } + + fn precompute_jt_value_indices( + jc_sym: &SymbolicSparseColMat, + jt_sym: &SymbolicSparseColMat, + ) -> Vec { + let mut indices = Vec::with_capacity(jt_sym.compute_nnz()); + let jc_row_idx = jc_sym.row_idx(); + let jt_row_idx = jt_sym.row_idx(); + + for jt_col in 0..jt_sym.ncols() { + let jt_col_range = jt_sym.col_range(jt_col); + for jt_idx in jt_col_range.clone() { + let original_col = jt_row_idx[jt_idx]; + let original_row = jt_col; + let mut jc_col_range = jc_sym.col_range(original_col); + let jc_idx = jc_col_range + .find(|idx| jc_row_idx[*idx] == original_row) + .expect("transpose symbolic structure mismatch"); + indices.push(jc_idx); + } + } + + indices + } + + fn precompute_symbolic_a( + jtj_sym: &SymbolicSparseColMat, + lambda_i: &faer::sparse::SparseColMat, + ) -> Result, NonLinearSystemError> { + let ones = vec![1.0; jtj_sym.compute_nnz()]; + let jtj = SparseColMatRef::new(jtj_sym.as_ref(), &ones); + let a = jtj + lambda_i; + Ok(a.symbolic().to_owned()?) + } + + fn precompute_a_from_jtj_indices( + jtj_sym: &SymbolicSparseColMat, + a_sym: &SymbolicSparseColMat, + ) -> Vec { + let mut indices = Vec::with_capacity(jtj_sym.compute_nnz()); + let a_row_idx = a_sym.row_idx(); + for col in 0..jtj_sym.ncols() { + let jtj_col_range = jtj_sym.col_range(col); + for jtj_idx in jtj_col_range.clone() { + let row = jtj_sym.row_idx()[jtj_idx]; + let mut a_col_range = a_sym.col_range(col); + let a_idx = a_col_range + .find(|idx| a_row_idx[*idx] == row) + .expect("A symbolic must contain all J^T J entries"); + indices.push(a_idx); + } + } + indices + } + + fn precompute_lambda_diag_indices(a_sym: &SymbolicSparseColMat) -> Vec { + let mut diag = Vec::with_capacity(a_sym.ncols()); + let row_idx = a_sym.row_idx(); + for col in 0..a_sym.ncols() { + let mut col_range = a_sym.col_range(col); + let idx = col_range + .find(|idx| row_idx[*idx] == col) + .expect("diagonal must exist in JᵀJ + λI"); + diag.push(idx); + } + diag + } + /// This is used in the core Newton solving, but it can be calculated entirely from /// the symbolic structure of the constraints. So let's do it here, before running /// the newton loop, to keep that loop fast. diff --git a/kcl-ezpz/src/solver/newton.rs b/kcl-ezpz/src/solver/newton.rs index e5fafbb5..b05275d3 100644 --- a/kcl-ezpz/src/solver/newton.rs +++ b/kcl-ezpz/src/solver/newton.rs @@ -1,7 +1,12 @@ use faer::{ - ColRef, + Accum, ColRef, Par, + dyn_stack::{MemBuffer, MemStack}, + get_global_parallelism, prelude::Solve, - sparse::{SparseColMatRef, linalg::solvers::Lu}, + sparse::{ + SparseColMatMut, SparseColMatRef, + linalg::{matmul, solvers::Lu}, + }, }; use crate::{Config, NonLinearSystemError}; @@ -16,15 +21,30 @@ impl Model<'_> { config: Config, ) -> Result { let m = self.layout.total_num_residuals; - let mut global_residual = vec![0.0; m]; + // Preallocate scratch space for computing JᵀJ. + let jtj_nnz = self.jtj_symbolic.0.compute_nnz(); + let jtj_scratch_req = { + let (jtj_sym, _) = &self.jtj_symbolic; + matmul::sparse_sparse_matmul_numeric_scratch::(jtj_sym.as_ref(), Par::Seq) + }; + let mut jtj_vals = vec![0.0; jtj_nnz]; + let mut jtj_mem = MemBuffer::new(jtj_scratch_req); + let mut jt_vals = vec![0.0; self.jt_sym.compute_nnz()]; + let mut a_vals = vec![0.0; self.a_sym.compute_nnz()]; + let mut b_vals = vec![0.0; self.layout.num_variables]; + for this_iteration in 0..config.max_iterations { // Assemble global residual and Jacobian // Re-evaluate the global residual. self.residual(current_values, &mut global_residual); // Re-evaluate the global jacobian, write it into self.jc self.refresh_jacobian(current_values); + let (jtj_sym, jtj_info) = &self.jtj_symbolic; + for (jt_idx, jc_idx) in self.jt_value_indices.iter().copied().enumerate() { + jt_vals[jt_idx] = self.jc.vals[jc_idx]; + } // Convergence check: if the residual is within our tolerance, // then the system is totally solved and we can return. @@ -41,15 +61,42 @@ impl Model<'_> { (JᵀJ + λI) d = -Jᵀr This involves creating a matrix A and rhs b where A = JᵀJ + λI - b = -Jᵀr + b = -Jᵀr */ let j = SparseColMatRef::new(self.jc.sym.as_ref(), &self.jc.vals); - // TODO: Is there any way to transpose `j` and keep it in column-major? - // Converting from row- to column-major might not be necessary. - let jtj = j.transpose().to_col_major()? * j; - let a = jtj + &self.lambda_i; - let b = j.transpose() * -ColRef::from_slice(&global_residual); + let jt = SparseColMatRef::new(self.jt_sym.as_ref(), &jt_vals); + + // Compute JᵀJ, reusing its symbolic structure. + let jtj_stack = MemStack::new(&mut jtj_mem); + matmul::sparse_sparse_matmul_numeric( + SparseColMatMut::new(jtj_sym.as_ref(), &mut jtj_vals), + Accum::Replace, + jt.as_ref(), + j, + 1.0, + jtj_info, + get_global_parallelism(), + jtj_stack, + ); + a_vals.fill(0.0); + for (jtj_idx, a_idx) in self.a_from_jtj_indices.iter().copied().enumerate() { + a_vals[a_idx] = jtj_vals[jtj_idx]; + } + for (col, diag_idx) in self.lambda_diag_indices.iter().copied().enumerate() { + a_vals[diag_idx] += self.lambda_i.val_of_col(col)[0]; + } + let a = SparseColMatRef::new(self.a_sym.as_ref(), &a_vals); + b_vals.fill(0.0); + let jt_row_idx = self.jt_sym.row_idx(); + for (col, residual_val) in global_residual.iter().enumerate().take(self.jt_sym.ncols()) + { + let col_range = self.jt_sym.col_range(col); + for idx in col_range.clone() { + b_vals[jt_row_idx[idx]] -= jt_vals[idx] * residual_val; + } + } + let b = ColRef::from_slice(&b_vals); // Solve linear system let factored = Lu::try_new_with_symbolic(self.lu_symbolic.clone(), a.as_ref())?;