Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 100 additions & 1 deletion kcl-ezpz/src/solver.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::sync::Mutex;

use faer::sparse::{Pair, SparseColMatRef, SymbolicSparseColMat, linalg::solvers::SymbolicLu};
use faer::sparse::{
Pair, SparseColMatRef, SymbolicSparseColMat,
linalg::{matmul::SparseMatMulInfo, solvers::SymbolicLu},
};

use crate::{
Constraint, ConstraintEntry, NonLinearSystemError, Warning, WarningContent,
Expand Down Expand Up @@ -96,6 +99,12 @@ pub(crate) struct Model<'c> {
pub(crate) warnings: Mutex<Vec<Warning>>,
lambda_i: faer::sparse::SparseColMat<usize, f64>,
lu_symbolic: SymbolicLu<usize>,
jt_sym: SymbolicSparseColMat<usize>,
jtj_symbolic: (SymbolicSparseColMat<usize>, SparseMatMulInfo),
jt_value_indices: Vec<usize>,
a_sym: SymbolicSparseColMat<usize>,
a_from_jtj_indices: Vec<usize>,
lambda_diag_indices: Vec<usize>,
}

fn validate_variables(
Expand Down Expand Up @@ -210,6 +219,12 @@ impl<'c> Model<'c> {

// Precompute the symbolic LU of A = JᵀJ + λI so we can reuse it inside the Newton loop.
let lu_symbolic = Self::precompute_symbolic_lu(&jc.sym, &lambda_i)?;
let jt_sym = jc.sym.transpose().to_col_major()?;
let jtj_symbolic = Self::precompute_symbolic_jtj(&jt_sym, &jc.sym)?;
let jt_value_indices = Self::precompute_jt_value_indices(&jc.sym, &jt_sym);
let a_sym = Self::precompute_symbolic_a(&jtj_symbolic.0, &lambda_i)?;
let a_from_jtj_indices = Self::precompute_a_from_jtj_indices(&jtj_symbolic.0, &a_sym);
let lambda_diag_indices = Self::precompute_lambda_diag_indices(&a_sym);

// All done.
Ok(Self {
Expand All @@ -221,9 +236,93 @@ impl<'c> Model<'c> {
row1_scratch: Vec::with_capacity(NONZEROES_PER_ROW),
lambda_i,
lu_symbolic,
jt_sym,
jtj_symbolic,
jt_value_indices,
a_sym,
a_from_jtj_indices,
lambda_diag_indices,
})
}

fn precompute_symbolic_jtj(
jt_sym: &SymbolicSparseColMat<usize>,
jc_sym: &SymbolicSparseColMat<usize>,
) -> Result<(SymbolicSparseColMat<usize>, SparseMatMulInfo), NonLinearSystemError> {
let jtj_sym = faer::sparse::linalg::matmul::sparse_sparse_matmul_symbolic(
jt_sym.as_ref(),
jc_sym.as_ref(),
)?;
Ok(jtj_sym)
}

fn precompute_jt_value_indices(
jc_sym: &SymbolicSparseColMat<usize>,
jt_sym: &SymbolicSparseColMat<usize>,
) -> Vec<usize> {
let mut indices = Vec::with_capacity(jt_sym.compute_nnz());
let jc_row_idx = jc_sym.row_idx();
let jt_row_idx = jt_sym.row_idx();

for jt_col in 0..jt_sym.ncols() {
let jt_col_range = jt_sym.col_range(jt_col);
for jt_idx in jt_col_range.clone() {
let original_col = jt_row_idx[jt_idx];
let original_row = jt_col;
let mut jc_col_range = jc_sym.col_range(original_col);
let jc_idx = jc_col_range
.find(|idx| jc_row_idx[*idx] == original_row)
.expect("transpose symbolic structure mismatch");
indices.push(jc_idx);
}
}

indices
}

fn precompute_symbolic_a(
jtj_sym: &SymbolicSparseColMat<usize>,
lambda_i: &faer::sparse::SparseColMat<usize, f64>,
) -> Result<SymbolicSparseColMat<usize>, NonLinearSystemError> {
let ones = vec![1.0; jtj_sym.compute_nnz()];
let jtj = SparseColMatRef::new(jtj_sym.as_ref(), &ones);
let a = jtj + lambda_i;
Ok(a.symbolic().to_owned()?)
}

fn precompute_a_from_jtj_indices(
jtj_sym: &SymbolicSparseColMat<usize>,
a_sym: &SymbolicSparseColMat<usize>,
) -> Vec<usize> {
let mut indices = Vec::with_capacity(jtj_sym.compute_nnz());
let a_row_idx = a_sym.row_idx();
for col in 0..jtj_sym.ncols() {
let jtj_col_range = jtj_sym.col_range(col);
for jtj_idx in jtj_col_range.clone() {
let row = jtj_sym.row_idx()[jtj_idx];
let mut a_col_range = a_sym.col_range(col);
let a_idx = a_col_range
.find(|idx| a_row_idx[*idx] == row)
.expect("A symbolic must contain all J^T J entries");
indices.push(a_idx);
}
}
indices
}

fn precompute_lambda_diag_indices(a_sym: &SymbolicSparseColMat<usize>) -> Vec<usize> {
let mut diag = Vec::with_capacity(a_sym.ncols());
let row_idx = a_sym.row_idx();
for col in 0..a_sym.ncols() {
let mut col_range = a_sym.col_range(col);
let idx = col_range
.find(|idx| row_idx[*idx] == col)
.expect("diagonal must exist in JᵀJ + λI");
diag.push(idx);
}
diag
}

/// This is used in the core Newton solving, but it can be calculated entirely from
/// the symbolic structure of the constraints. So let's do it here, before running
/// the newton loop, to keep that loop fast.
Expand Down
65 changes: 56 additions & 9 deletions kcl-ezpz/src/solver/newton.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use faer::{
ColRef,
Accum, ColRef, Par,
dyn_stack::{MemBuffer, MemStack},
get_global_parallelism,
prelude::Solve,
sparse::{SparseColMatRef, linalg::solvers::Lu},
sparse::{
SparseColMatMut, SparseColMatRef,
linalg::{matmul, solvers::Lu},
},
};

use crate::{Config, NonLinearSystemError};
Expand All @@ -16,15 +21,30 @@ impl Model<'_> {
config: Config,
) -> Result<usize, NonLinearSystemError> {
let m = self.layout.total_num_residuals;

let mut global_residual = vec![0.0; m];

// Preallocate scratch space for computing JᵀJ.
let jtj_nnz = self.jtj_symbolic.0.compute_nnz();
let jtj_scratch_req = {
let (jtj_sym, _) = &self.jtj_symbolic;
matmul::sparse_sparse_matmul_numeric_scratch::<usize, f64>(jtj_sym.as_ref(), Par::Seq)
};
let mut jtj_vals = vec![0.0; jtj_nnz];
let mut jtj_mem = MemBuffer::new(jtj_scratch_req);
let mut jt_vals = vec![0.0; self.jt_sym.compute_nnz()];
let mut a_vals = vec![0.0; self.a_sym.compute_nnz()];
let mut b_vals = vec![0.0; self.layout.num_variables];

for this_iteration in 0..config.max_iterations {
// Assemble global residual and Jacobian
// Re-evaluate the global residual.
self.residual(current_values, &mut global_residual);
// Re-evaluate the global jacobian, write it into self.jc
self.refresh_jacobian(current_values);
let (jtj_sym, jtj_info) = &self.jtj_symbolic;
for (jt_idx, jc_idx) in self.jt_value_indices.iter().copied().enumerate() {
jt_vals[jt_idx] = self.jc.vals[jc_idx];
}

// Convergence check: if the residual is within our tolerance,
// then the system is totally solved and we can return.
Expand All @@ -41,15 +61,42 @@ impl Model<'_> {
(JᵀJ + λI) d = -Jᵀr
This involves creating a matrix A and rhs b where
A = JᵀJ + λI
b = -Jᵀr
b = -Jᵀr
*/

let j = SparseColMatRef::new(self.jc.sym.as_ref(), &self.jc.vals);
// TODO: Is there any way to transpose `j` and keep it in column-major?
// Converting from row- to column-major might not be necessary.
let jtj = j.transpose().to_col_major()? * j;
let a = jtj + &self.lambda_i;
let b = j.transpose() * -ColRef::from_slice(&global_residual);
let jt = SparseColMatRef::new(self.jt_sym.as_ref(), &jt_vals);

// Compute JᵀJ, reusing its symbolic structure.
let jtj_stack = MemStack::new(&mut jtj_mem);
matmul::sparse_sparse_matmul_numeric(
SparseColMatMut::new(jtj_sym.as_ref(), &mut jtj_vals),
Accum::Replace,
jt.as_ref(),
j,
1.0,
jtj_info,
get_global_parallelism(),
jtj_stack,
);
a_vals.fill(0.0);
for (jtj_idx, a_idx) in self.a_from_jtj_indices.iter().copied().enumerate() {
a_vals[a_idx] = jtj_vals[jtj_idx];
}
for (col, diag_idx) in self.lambda_diag_indices.iter().copied().enumerate() {
a_vals[diag_idx] += self.lambda_i.val_of_col(col)[0];
}
let a = SparseColMatRef::new(self.a_sym.as_ref(), &a_vals);
b_vals.fill(0.0);
let jt_row_idx = self.jt_sym.row_idx();
for (col, residual_val) in global_residual.iter().enumerate().take(self.jt_sym.ncols())
{
let col_range = self.jt_sym.col_range(col);
for idx in col_range.clone() {
b_vals[jt_row_idx[idx]] -= jt_vals[idx] * residual_val;
}
}
let b = ColRef::from_slice(&b_vals);

// Solve linear system
let factored = Lu::try_new_with_symbolic(self.lu_symbolic.clone(), a.as_ref())?;
Expand Down