Skip to content

Commit 57159b7

Browse files
authored
Add batched tridiagonal solver with Kokkos parallelization (#162)
1 parent 1ba7fb1 commit 57159b7

File tree

3 files changed

+598
-0
lines changed

3 files changed

+598
-0
lines changed
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
#pragma once
2+
3+
#include <stdexcept>

#include <Kokkos_Core.hpp>

#include "../vector.h"
#include "../vector_operations.h"
7+
8+
template <typename T>
9+
class BatchedTridiagonalSolver
10+
{
11+
public:
12+
BatchedTridiagonalSolver(int matrix_dimension, int batch_count, bool is_cyclic = true)
13+
: matrix_dimension_(matrix_dimension)
14+
, batch_count_(batch_count)
15+
, main_diagonal_("BatchedTridiagonalSolver::main_diagonal", matrix_dimension * batch_count)
16+
, sub_diagonal_("BatchedTridiagonalSolver::sub_diagonal", matrix_dimension * batch_count)
17+
, buffer_("BatchedTridiagonalSolver::buffer", is_cyclic ? matrix_dimension * batch_count : 0)
18+
, gamma_("BatchedTridiagonalSolver::gamma", is_cyclic ? batch_count : 0)
19+
, is_cyclic_(is_cyclic)
20+
, is_factorized_(false)
21+
{
22+
assign(main_diagonal_, T(0));
23+
assign(sub_diagonal_, T(0));
24+
}
25+
26+
/* ---------------------------- */
27+
/* Accessors for matrix entries */
28+
/* ---------------------------- */
29+
30+
KOKKOS_INLINE_FUNCTION
31+
const T& main_diagonal(const int batch_idx, const int index) const
32+
{
33+
return main_diagonal_(batch_idx * matrix_dimension_ + index);
34+
}
35+
KOKKOS_INLINE_FUNCTION
36+
T& main_diagonal(const int batch_idx, const int index)
37+
{
38+
return main_diagonal_(batch_idx * matrix_dimension_ + index);
39+
}
40+
41+
KOKKOS_INLINE_FUNCTION
42+
const T& sub_diagonal(const int batch_idx, const int index) const
43+
{
44+
return sub_diagonal_(batch_idx * matrix_dimension_ + index);
45+
}
46+
KOKKOS_INLINE_FUNCTION
47+
T& sub_diagonal(const int batch_idx, const int index)
48+
{
49+
return sub_diagonal_(batch_idx * matrix_dimension_ + index);
50+
}
51+
52+
KOKKOS_INLINE_FUNCTION
53+
const T& cyclic_corner(const int batch_idx) const
54+
{
55+
return sub_diagonal_(batch_idx * matrix_dimension_ + (matrix_dimension_ - 1));
56+
}
57+
58+
KOKKOS_INLINE_FUNCTION
59+
T& cyclic_corner(const int batch_idx)
60+
{
61+
return sub_diagonal_(batch_idx * matrix_dimension_ + (matrix_dimension_ - 1));
62+
}
63+
64+
/* ---------------------------------------------- */
65+
/* Setup: Cholesky Decomposition: A = L * D * L^T */
66+
/* ---------------------------------------------- */
67+
// This step factorizes the tridiagonal matrix into lower triangular (L) and diagonal (D) matrices.
68+
// For cyclic systems, it also applies the Shermann-Morrison adjustment to account for the cyclic connection.
69+
70+
void setup()
71+
{
72+
// Create local copies for lambda capture
73+
int matrix_dimension = matrix_dimension_;
74+
Vector<T> main_diagonal = main_diagonal_;
75+
Vector<T> sub_diagonal = sub_diagonal_;
76+
Vector<T> gamma = gamma_;
77+
78+
if (!is_cyclic_) {
79+
Kokkos::parallel_for(
80+
"SetupNonCyclic", batch_count_, KOKKOS_LAMBDA(const int batch_idx) {
81+
// ----------------------------------- //
82+
// Obtain offset for the current batch //
83+
int offset = batch_idx * matrix_dimension;
84+
85+
// ---------------------- //
86+
// Cholesky Decomposition //
87+
for (int i = 1; i < matrix_dimension; i++) {
88+
sub_diagonal(offset + i - 1) /= main_diagonal(offset + i - 1);
89+
const T factor = sub_diagonal(offset + i - 1);
90+
main_diagonal(offset + i) -= factor * factor * main_diagonal(offset + i - 1);
91+
}
92+
});
93+
}
94+
else {
95+
Kokkos::parallel_for(
96+
"SetupCyclic", batch_count_, KOKKOS_LAMBDA(const int batch_idx) {
97+
// ----------------------------------- //
98+
// Obtain offset for the current batch //
99+
int offset = batch_idx * matrix_dimension;
100+
101+
// ------------------------------------------------- //
102+
// Shermann-Morrison Adjustment //
103+
// - Modify the first and last main diagonal element //
104+
// - Compute and store gamma for later use //
105+
// ------------------------------------------------- //
106+
T cyclic_corner_element = sub_diagonal(offset + matrix_dimension - 1);
107+
gamma(batch_idx) = -main_diagonal(offset + 0);
108+
main_diagonal(offset + 0) -= gamma(batch_idx);
109+
main_diagonal(offset + matrix_dimension - 1) -=
110+
cyclic_corner_element * cyclic_corner_element / gamma(batch_idx);
111+
112+
// ---------------------- //
113+
// Cholesky Decomposition //
114+
for (int i = 1; i < matrix_dimension; i++) {
115+
sub_diagonal(offset + i - 1) /= main_diagonal(offset + i - 1);
116+
const T factor = sub_diagonal(offset + i - 1);
117+
main_diagonal(offset + i) -= factor * factor * main_diagonal(offset + i - 1);
118+
}
119+
});
120+
}
121+
Kokkos::fence();
122+
is_factorized_ = true;
123+
}
124+
125+
/* ---------------------------------------- */
126+
/* Solve: Forward and Backward Substitution */
127+
/* ---------------------------------------- */
128+
// This step solves the system Ax = b using the factorized form of A.
129+
// For cyclic systems, it also performs the Shermann-Morrison reconstruction to obtain the final solution.
130+
131+
void solve(Vector<T> rhs, int batch_offset = 0, int batch_stride = 1)
132+
{
133+
if (!is_factorized_) {
134+
throw std::runtime_error("Error: Matrix must be factorized before solving.");
135+
}
136+
137+
// Compute the effective number of batches to solve
138+
int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride;
139+
140+
// Create local copies for lambda capture
141+
int matrix_dimension = matrix_dimension_;
142+
Vector<T> main_diagonal = main_diagonal_;
143+
Vector<T> sub_diagonal = sub_diagonal_;
144+
Vector<T> buffer = buffer_;
145+
Vector<T> gamma = gamma_;
146+
147+
if (!is_cyclic_) {
148+
Kokkos::parallel_for(
149+
"SolveNonCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) {
150+
// ----------------------------------- //
151+
// Obtain offset for the current batch //
152+
int batch_idx = batch_stride * k + batch_offset;
153+
int offset = batch_idx * matrix_dimension;
154+
155+
// -------------------- //
156+
// Forward Substitution //
157+
for (int i = 1; i < matrix_dimension; i++) {
158+
rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1);
159+
}
160+
161+
// ---------------- //
162+
// Diagonal Scaling //
163+
for (int i = 0; i < matrix_dimension; i++) {
164+
rhs(offset + i) /= main_diagonal(offset + i);
165+
}
166+
167+
// --------------------- //
168+
// Backward Substitution //
169+
for (int i = matrix_dimension - 2; i >= 0; i--) {
170+
rhs(offset + i) -= sub_diagonal(offset + i) * rhs(offset + i + 1);
171+
}
172+
});
173+
}
174+
else {
175+
Kokkos::parallel_for(
176+
"SolveCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) {
177+
// ----------------------------------- //
178+
// Obtain offset for the current batch //
179+
int batch_idx = batch_stride * k + batch_offset;
180+
int offset = batch_idx * matrix_dimension;
181+
182+
// -------------------- //
183+
// Forward Substitution //
184+
T cyclic_corner_element = sub_diagonal(offset + matrix_dimension - 1);
185+
buffer(offset + 0) = gamma(batch_idx);
186+
for (int i = 1; i < matrix_dimension; i++) {
187+
rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1);
188+
if (i < matrix_dimension - 1)
189+
buffer(offset + i) = 0.0 - sub_diagonal(offset + i - 1) * buffer(offset + i - 1);
190+
else
191+
buffer(offset + i) =
192+
cyclic_corner_element - sub_diagonal(offset + i - 1) * buffer(offset + i - 1);
193+
}
194+
195+
// ---------------- //
196+
// Diagonal Scaling //
197+
for (int i = 0; i < matrix_dimension; i++) {
198+
rhs(offset + i) /= main_diagonal(offset + i);
199+
buffer(offset + i) /= main_diagonal(offset + i);
200+
}
201+
202+
// --------------------- //
203+
// Backward Substitution //
204+
for (int i = matrix_dimension - 2; i >= 0; i--) {
205+
rhs(offset + i) -= sub_diagonal(offset + i) * rhs(offset + i + 1);
206+
buffer(offset + i) -= sub_diagonal(offset + i) * buffer(offset + i + 1);
207+
}
208+
209+
// ------------------------------- //
210+
// Shermann-Morrison Reonstruction //
211+
const T dot_product_x_v =
212+
rhs(offset + 0) + cyclic_corner_element / gamma(batch_idx) * rhs(offset + matrix_dimension - 1);
213+
const T dot_product_u_v = buffer(offset + 0) + cyclic_corner_element / gamma(batch_idx) *
214+
buffer(offset + matrix_dimension - 1);
215+
const T factor = dot_product_x_v / (1.0 + dot_product_u_v);
216+
217+
for (int i = 0; i < matrix_dimension; i++) {
218+
rhs(offset + i) -= factor * buffer(offset + i);
219+
}
220+
});
221+
}
222+
Kokkos::fence();
223+
}
224+
225+
/* ---------------------------- */
226+
/* Solve: Diagonal Scaling Only */
227+
/* ---------------------------- */
228+
// This step performs only the diagonal scaling part of the solve process.
229+
// It is useful when the matrix has a non-zero diagonal but zero off-diagonal entries.
230+
// Note that .setup() modifies main_diagonal(0) in the cyclic case.
231+
232+
void solve_diagonal(Vector<T> rhs, int batch_offset = 0, int batch_stride = 1)
233+
{
234+
if (!is_factorized_) {
235+
throw std::runtime_error("Error: Matrix must be factorized before solving.");
236+
}
237+
238+
// Compute the effective number of batches to solve
239+
int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride;
240+
241+
// Create local copies for lambda capture
242+
int matrix_dimension = matrix_dimension_;
243+
Vector<T> main_diagonal = main_diagonal_;
244+
Vector<T> gamma = gamma_;
245+
246+
if (!is_cyclic_) {
247+
Kokkos::parallel_for(
248+
"SolveDiagonalNonCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) {
249+
// ----------------------------------- //
250+
// Obtain offset for the current batch //
251+
int batch_idx = batch_stride * k + batch_offset;
252+
int offset = batch_idx * matrix_dimension;
253+
254+
// ---------------- //
255+
// Diagonal Scaling //
256+
for (int i = 0; i < matrix_dimension; i++) {
257+
rhs(offset + i) /= main_diagonal(offset + i);
258+
}
259+
});
260+
}
261+
else {
262+
Kokkos::parallel_for(
263+
"SolveDiagonalCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) {
264+
// ----------------------------------- //
265+
// Obtain offset for the current batch //
266+
int batch_idx = batch_stride * k + batch_offset;
267+
int offset = batch_idx * matrix_dimension;
268+
269+
// ---------------- //
270+
// Diagonal Scaling //
271+
rhs(offset + 0) /= main_diagonal(offset + 0) + gamma(batch_idx);
272+
for (int i = 1; i < matrix_dimension; i++) {
273+
rhs(offset + i) /= main_diagonal(offset + i);
274+
}
275+
});
276+
}
277+
Kokkos::fence();
278+
}
279+
280+
private:
281+
int matrix_dimension_;
282+
int batch_count_;
283+
284+
Vector<T> main_diagonal_;
285+
Vector<T> sub_diagonal_;
286+
Vector<T> buffer_;
287+
Vector<T> gamma_;
288+
289+
bool is_cyclic_;
290+
bool is_factorized_;
291+
};

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_executable(gmgpolar_tests
1414
LinearAlgebra/csr_solver.cpp
1515
LinearAlgebra/tridiagonal_solver.cpp
1616
LinearAlgebra/cyclic_tridiagonal_solver.cpp
17+
LinearAlgebra/Solvers/tridiagonal_solver.cpp
1718
PolarGrid/polargrid.cpp
1819
Interpolation/prolongation.cpp
1920
Interpolation/restriction.cpp

0 commit comments

Comments
 (0)