Skip to content

Commit bbda039

Browse files
committed
Better again
1 parent 182ffe1 commit bbda039

File tree

2 files changed

+21
-36
lines changed

2 files changed

+21
-36
lines changed

kcl-ezpz/src/solver.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ pub(crate) struct Model<'c> {
101101
lu_symbolic: SymbolicLu<usize>,
102102
jt_sym: SymbolicSparseColMat<usize>,
103103
jtj_symbolic: (SymbolicSparseColMat<usize>, SparseMatMulInfo),
104-
a_sym: SymbolicSparseColMat<usize>,
105104
jt_value_indices: Vec<usize>,
105+
lambda_diag_indices: Vec<usize>,
106106
}
107107

108108
fn validate_variables(
@@ -219,8 +219,8 @@ impl<'c> Model<'c> {
219219
let lu_symbolic = Self::precompute_symbolic_lu(&jc.sym, &lambda_i)?;
220220
let jt_sym = jc.sym.transpose().to_col_major()?;
221221
let jtj_symbolic = Self::precompute_symbolic_jtj(&jt_sym, &jc.sym)?;
222-
let a_sym = Self::precompute_symbolic_a(&jtj_symbolic.0, &lambda_i)?;
223222
let jt_value_indices = Self::precompute_jt_value_indices(&jc.sym, &jt_sym);
223+
let lambda_diag_indices = Self::precompute_lambda_diag_indices(&jtj_symbolic.0);
224224

225225
// All done.
226226
Ok(Self {
@@ -234,8 +234,8 @@ impl<'c> Model<'c> {
234234
lu_symbolic,
235235
jt_sym,
236236
jtj_symbolic,
237-
a_sym,
238237
jt_value_indices,
238+
lambda_diag_indices,
239239
})
240240
}
241241

@@ -250,17 +250,6 @@ impl<'c> Model<'c> {
250250
Ok(jtj_sym)
251251
}
252252

253-
fn precompute_symbolic_a(
254-
jtj_sym: &SymbolicSparseColMat<usize>,
255-
lambda_i: &faer::sparse::SparseColMat<usize, f64>,
256-
) -> Result<SymbolicSparseColMat<usize>, NonLinearSystemError> {
257-
// Any non-zero values will do; we only care about the sparsity pattern of JᵀJ + λI.
258-
let ones = vec![1.0; jtj_sym.compute_nnz()];
259-
let jtj = SparseColMatRef::new(jtj_sym.as_ref(), &ones);
260-
let a = jtj + lambda_i;
261-
Ok(a.symbolic().to_owned()?)
262-
}
263-
264253
fn precompute_jt_value_indices(
265254
jc_sym: &SymbolicSparseColMat<usize>,
266255
jt_sym: &SymbolicSparseColMat<usize>,
@@ -285,6 +274,19 @@ impl<'c> Model<'c> {
285274
indices
286275
}
287276

277+
fn precompute_lambda_diag_indices(jtj_sym: &SymbolicSparseColMat<usize>) -> Vec<usize> {
278+
let mut diag = Vec::with_capacity(jtj_sym.ncols());
279+
let row_idx = jtj_sym.row_idx();
280+
for col in 0..jtj_sym.ncols() {
281+
let mut col_range = jtj_sym.col_range(col);
282+
let idx = col_range
283+
.find(|idx| row_idx[*idx] == col)
284+
.expect("diagonal must exist in J^T J");
285+
diag.push(idx);
286+
}
287+
diag
288+
}
289+
288290
/// This is used in the core Newton solving, but it can be calculated entirely from
289291
/// the symbolic structure of the constraints. So let's do it here, before running
290292
/// the newton loop, to keep that loop fast.

kcl-ezpz/src/solver/newton.rs

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use faer::{
66
sparse::{
77
SparseColMatMut, SparseColMatRef,
88
linalg::{matmul, solvers::Lu},
9-
ops,
109
},
1110
};
1211

@@ -32,7 +31,6 @@ impl Model<'_> {
3231
};
3332
let mut jtj_vals = vec![0.0; jtj_nnz];
3433
let mut jtj_mem = MemBuffer::new(jtj_scratch_req);
35-
let mut a_vals = vec![0.0; self.a_sym.compute_nnz()];
3634
let mut jt_vals = vec![0.0; self.jt_sym.compute_nnz()];
3735

3836
for this_iteration in 0..config.max_iterations {
@@ -79,26 +77,11 @@ impl Model<'_> {
7977
get_global_parallelism(),
8078
jtj_stack,
8179
);
82-
let jtj = SparseColMatRef::new(jtj_sym.as_ref(), &jtj_vals);
83-
84-
a_vals.fill(0.0);
85-
ops::binary_op_assign_into(
86-
SparseColMatMut::new(self.a_sym.as_ref(), &mut a_vals),
87-
jtj,
88-
|dst, src| {
89-
*dst = *src.unwrap_or(&0.0);
90-
},
91-
);
92-
ops::binary_op_assign_into(
93-
SparseColMatMut::new(self.a_sym.as_ref(), &mut a_vals),
94-
self.lambda_i.as_ref(),
95-
|dst, src| {
96-
if let Some(val) = src {
97-
*dst += *val;
98-
}
99-
},
100-
);
101-
let a = SparseColMatRef::new(self.a_sym.as_ref(), &a_vals);
80+
let mut a_vals = jtj_vals.clone();
81+
for (col, diag_idx) in self.lambda_diag_indices.iter().copied().enumerate() {
82+
a_vals[diag_idx] += self.lambda_i.val_of_col(col)[0];
83+
}
84+
let a = SparseColMatRef::new(jtj_sym.as_ref(), &a_vals);
10285
let b = j.transpose() * -ColRef::from_slice(&global_residual);
10386

10487
// Solve linear system

0 commit comments

Comments
 (0)