|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <Kokkos_Core.hpp> |
| 4 | + |
| 5 | +#include "../vector.h" |
| 6 | +#include "../vector_operations.h" |
| 7 | + |
| 8 | +template <typename T> |
| 9 | +class BatchedTridiagonalSolver |
| 10 | +{ |
| 11 | +public: |
| 12 | + BatchedTridiagonalSolver(int matrix_dimension, int batch_count, bool is_cyclic = true) |
| 13 | + : matrix_dimension_(matrix_dimension) |
| 14 | + , batch_count_(batch_count) |
| 15 | + , main_diagonal_("BatchedTridiagonalSolver::main_diagonal", matrix_dimension * batch_count) |
| 16 | + , sub_diagonal_("BatchedTridiagonalSolver::sub_diagonal", matrix_dimension * batch_count) |
| 17 | + , buffer_("BatchedTridiagonalSolver::buffer", is_cyclic ? matrix_dimension * batch_count : 0) |
| 18 | + , gamma_("BatchedTridiagonalSolver::gamma", is_cyclic ? batch_count : 0) |
| 19 | + , is_cyclic_(is_cyclic) |
| 20 | + , is_factorized_(false) |
| 21 | + { |
| 22 | + assign(main_diagonal_, T(0)); |
| 23 | + assign(sub_diagonal_, T(0)); |
| 24 | + } |
| 25 | + |
| 26 | + /* ---------------------------- */ |
| 27 | + /* Accessors for matrix entries */ |
| 28 | + /* ---------------------------- */ |
| 29 | + |
| 30 | + KOKKOS_INLINE_FUNCTION |
| 31 | + const T& main_diagonal(const int batch_idx, const int index) const |
| 32 | + { |
| 33 | + return main_diagonal_(batch_idx * matrix_dimension_ + index); |
| 34 | + } |
| 35 | + KOKKOS_INLINE_FUNCTION |
| 36 | + T& main_diagonal(const int batch_idx, const int index) |
| 37 | + { |
| 38 | + return main_diagonal_(batch_idx * matrix_dimension_ + index); |
| 39 | + } |
| 40 | + |
| 41 | + KOKKOS_INLINE_FUNCTION |
| 42 | + const T& sub_diagonal(const int batch_idx, const int index) const |
| 43 | + { |
| 44 | + return sub_diagonal_(batch_idx * matrix_dimension_ + index); |
| 45 | + } |
| 46 | + KOKKOS_INLINE_FUNCTION |
| 47 | + T& sub_diagonal(const int batch_idx, const int index) |
| 48 | + { |
| 49 | + return sub_diagonal_(batch_idx * matrix_dimension_ + index); |
| 50 | + } |
| 51 | + |
| 52 | + KOKKOS_INLINE_FUNCTION |
| 53 | + const T& cyclic_corner(const int batch_idx) const |
| 54 | + { |
| 55 | + return sub_diagonal_(batch_idx * matrix_dimension_ + (matrix_dimension_ - 1)); |
| 56 | + } |
| 57 | + |
| 58 | + KOKKOS_INLINE_FUNCTION |
| 59 | + T& cyclic_corner(const int batch_idx) |
| 60 | + { |
| 61 | + return sub_diagonal_(batch_idx * matrix_dimension_ + (matrix_dimension_ - 1)); |
| 62 | + } |
| 63 | + |
| 64 | + /* ---------------------------------------------- */ |
| 65 | + /* Setup: Cholesky Decomposition: A = L * D * L^T */ |
| 66 | + /* ---------------------------------------------- */ |
| 67 | + // This step factorizes the tridiagonal matrix into lower triangular (L) and diagonal (D) matrices. |
| 68 | + // For cyclic systems, it also applies the Shermann-Morrison adjustment to account for the cyclic connection. |
| 69 | + |
| 70 | + void setup() |
| 71 | + { |
| 72 | + // Create local copies for lambda capture |
| 73 | + int matrix_dimension = matrix_dimension_; |
| 74 | + Vector<T> main_diagonal = main_diagonal_; |
| 75 | + Vector<T> sub_diagonal = sub_diagonal_; |
| 76 | + Vector<T> gamma = gamma_; |
| 77 | + |
| 78 | + if (!is_cyclic_) { |
| 79 | + Kokkos::parallel_for( |
| 80 | + "SetupNonCyclic", batch_count_, KOKKOS_LAMBDA(const int batch_idx) { |
| 81 | + // ----------------------------------- // |
| 82 | + // Obtain offset for the current batch // |
| 83 | + int offset = batch_idx * matrix_dimension; |
| 84 | + |
| 85 | + // ---------------------- // |
| 86 | + // Cholesky Decomposition // |
| 87 | + for (int i = 1; i < matrix_dimension; i++) { |
| 88 | + sub_diagonal(offset + i - 1) /= main_diagonal(offset + i - 1); |
| 89 | + const T factor = sub_diagonal(offset + i - 1); |
| 90 | + main_diagonal(offset + i) -= factor * factor * main_diagonal(offset + i - 1); |
| 91 | + } |
| 92 | + }); |
| 93 | + } |
| 94 | + else { |
| 95 | + Kokkos::parallel_for( |
| 96 | + "SetupCyclic", batch_count_, KOKKOS_LAMBDA(const int batch_idx) { |
| 97 | + // ----------------------------------- // |
| 98 | + // Obtain offset for the current batch // |
| 99 | + int offset = batch_idx * matrix_dimension; |
| 100 | + |
| 101 | + // ------------------------------------------------- // |
| 102 | + // Shermann-Morrison Adjustment // |
| 103 | + // - Modify the first and last main diagonal element // |
| 104 | + // - Compute and store gamma for later use // |
| 105 | + // ------------------------------------------------- // |
| 106 | + T cyclic_corner_element = sub_diagonal(offset + matrix_dimension - 1); |
| 107 | + gamma(batch_idx) = -main_diagonal(offset + 0); |
| 108 | + main_diagonal(offset + 0) -= gamma(batch_idx); |
| 109 | + main_diagonal(offset + matrix_dimension - 1) -= |
| 110 | + cyclic_corner_element * cyclic_corner_element / gamma(batch_idx); |
| 111 | + |
| 112 | + // ---------------------- // |
| 113 | + // Cholesky Decomposition // |
| 114 | + for (int i = 1; i < matrix_dimension; i++) { |
| 115 | + sub_diagonal(offset + i - 1) /= main_diagonal(offset + i - 1); |
| 116 | + const T factor = sub_diagonal(offset + i - 1); |
| 117 | + main_diagonal(offset + i) -= factor * factor * main_diagonal(offset + i - 1); |
| 118 | + } |
| 119 | + }); |
| 120 | + } |
| 121 | + Kokkos::fence(); |
| 122 | + is_factorized_ = true; |
| 123 | + } |
| 124 | + |
| 125 | + /* ---------------------------------------- */ |
| 126 | + /* Solve: Forward and Backward Substitution */ |
| 127 | + /* ---------------------------------------- */ |
| 128 | + // This step solves the system Ax = b using the factorized form of A. |
| 129 | + // For cyclic systems, it also performs the Shermann-Morrison reconstruction to obtain the final solution. |
| 130 | + |
| 131 | + void solve(Vector<T> rhs, int batch_offset = 0, int batch_stride = 1) |
| 132 | + { |
| 133 | + if (!is_factorized_) { |
| 134 | + throw std::runtime_error("Error: Matrix must be factorized before solving."); |
| 135 | + } |
| 136 | + |
| 137 | + // Compute the effective number of batches to solve |
| 138 | + int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; |
| 139 | + |
| 140 | + // Create local copies for lambda capture |
| 141 | + int matrix_dimension = matrix_dimension_; |
| 142 | + Vector<T> main_diagonal = main_diagonal_; |
| 143 | + Vector<T> sub_diagonal = sub_diagonal_; |
| 144 | + Vector<T> buffer = buffer_; |
| 145 | + Vector<T> gamma = gamma_; |
| 146 | + |
| 147 | + if (!is_cyclic_) { |
| 148 | + Kokkos::parallel_for( |
| 149 | + "SolveNonCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { |
| 150 | + // ----------------------------------- // |
| 151 | + // Obtain offset for the current batch // |
| 152 | + int batch_idx = batch_stride * k + batch_offset; |
| 153 | + int offset = batch_idx * matrix_dimension; |
| 154 | + |
| 155 | + // -------------------- // |
| 156 | + // Forward Substitution // |
| 157 | + for (int i = 1; i < matrix_dimension; i++) { |
| 158 | + rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1); |
| 159 | + } |
| 160 | + |
| 161 | + // ---------------- // |
| 162 | + // Diagonal Scaling // |
| 163 | + for (int i = 0; i < matrix_dimension; i++) { |
| 164 | + rhs(offset + i) /= main_diagonal(offset + i); |
| 165 | + } |
| 166 | + |
| 167 | + // --------------------- // |
| 168 | + // Backward Substitution // |
| 169 | + for (int i = matrix_dimension - 2; i >= 0; i--) { |
| 170 | + rhs(offset + i) -= sub_diagonal(offset + i) * rhs(offset + i + 1); |
| 171 | + } |
| 172 | + }); |
| 173 | + } |
| 174 | + else { |
| 175 | + Kokkos::parallel_for( |
| 176 | + "SolveCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { |
| 177 | + // ----------------------------------- // |
| 178 | + // Obtain offset for the current batch // |
| 179 | + int batch_idx = batch_stride * k + batch_offset; |
| 180 | + int offset = batch_idx * matrix_dimension; |
| 181 | + |
| 182 | + // -------------------- // |
| 183 | + // Forward Substitution // |
| 184 | + T cyclic_corner_element = sub_diagonal(offset + matrix_dimension - 1); |
| 185 | + buffer(offset + 0) = gamma(batch_idx); |
| 186 | + for (int i = 1; i < matrix_dimension; i++) { |
| 187 | + rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1); |
| 188 | + if (i < matrix_dimension - 1) |
| 189 | + buffer(offset + i) = 0.0 - sub_diagonal(offset + i - 1) * buffer(offset + i - 1); |
| 190 | + else |
| 191 | + buffer(offset + i) = |
| 192 | + cyclic_corner_element - sub_diagonal(offset + i - 1) * buffer(offset + i - 1); |
| 193 | + } |
| 194 | + |
| 195 | + // ---------------- // |
| 196 | + // Diagonal Scaling // |
| 197 | + for (int i = 0; i < matrix_dimension; i++) { |
| 198 | + rhs(offset + i) /= main_diagonal(offset + i); |
| 199 | + buffer(offset + i) /= main_diagonal(offset + i); |
| 200 | + } |
| 201 | + |
| 202 | + // --------------------- // |
| 203 | + // Backward Substitution // |
| 204 | + for (int i = matrix_dimension - 2; i >= 0; i--) { |
| 205 | + rhs(offset + i) -= sub_diagonal(offset + i) * rhs(offset + i + 1); |
| 206 | + buffer(offset + i) -= sub_diagonal(offset + i) * buffer(offset + i + 1); |
| 207 | + } |
| 208 | + |
| 209 | + // ------------------------------- // |
| 210 | + // Shermann-Morrison Reonstruction // |
| 211 | + const T dot_product_x_v = |
| 212 | + rhs(offset + 0) + cyclic_corner_element / gamma(batch_idx) * rhs(offset + matrix_dimension - 1); |
| 213 | + const T dot_product_u_v = buffer(offset + 0) + cyclic_corner_element / gamma(batch_idx) * |
| 214 | + buffer(offset + matrix_dimension - 1); |
| 215 | + const T factor = dot_product_x_v / (1.0 + dot_product_u_v); |
| 216 | + |
| 217 | + for (int i = 0; i < matrix_dimension; i++) { |
| 218 | + rhs(offset + i) -= factor * buffer(offset + i); |
| 219 | + } |
| 220 | + }); |
| 221 | + } |
| 222 | + Kokkos::fence(); |
| 223 | + } |
| 224 | + |
| 225 | + /* ---------------------------- */ |
| 226 | + /* Solve: Diagonal Scaling Only */ |
| 227 | + /* ---------------------------- */ |
| 228 | + // This step performs only the diagonal scaling part of the solve process. |
| 229 | + // It is useful when the matrix has a non-zero diagonal but zero off-diagonal entries. |
| 230 | + // Note that .setup() modifies main_diagonal(0) in the cyclic case. |
| 231 | + |
| 232 | + void solve_diagonal(Vector<T> rhs, int batch_offset = 0, int batch_stride = 1) |
| 233 | + { |
| 234 | + if (!is_factorized_) { |
| 235 | + throw std::runtime_error("Error: Matrix must be factorized before solving."); |
| 236 | + } |
| 237 | + |
| 238 | + // Compute the effective number of batches to solve |
| 239 | + int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; |
| 240 | + |
| 241 | + // Create local copies for lambda capture |
| 242 | + int matrix_dimension = matrix_dimension_; |
| 243 | + Vector<T> main_diagonal = main_diagonal_; |
| 244 | + Vector<T> gamma = gamma_; |
| 245 | + |
| 246 | + if (!is_cyclic_) { |
| 247 | + Kokkos::parallel_for( |
| 248 | + "SolveDiagonalNonCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { |
| 249 | + // ----------------------------------- // |
| 250 | + // Obtain offset for the current batch // |
| 251 | + int batch_idx = batch_stride * k + batch_offset; |
| 252 | + int offset = batch_idx * matrix_dimension; |
| 253 | + |
| 254 | + // ---------------- // |
| 255 | + // Diagonal Scaling // |
| 256 | + for (int i = 0; i < matrix_dimension; i++) { |
| 257 | + rhs(offset + i) /= main_diagonal(offset + i); |
| 258 | + } |
| 259 | + }); |
| 260 | + } |
| 261 | + else { |
| 262 | + Kokkos::parallel_for( |
| 263 | + "SolveDiagonalCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { |
| 264 | + // ----------------------------------- // |
| 265 | + // Obtain offset for the current batch // |
| 266 | + int batch_idx = batch_stride * k + batch_offset; |
| 267 | + int offset = batch_idx * matrix_dimension; |
| 268 | + |
| 269 | + // ---------------- // |
| 270 | + // Diagonal Scaling // |
| 271 | + rhs(offset + 0) /= main_diagonal(offset + 0) + gamma(batch_idx); |
| 272 | + for (int i = 1; i < matrix_dimension; i++) { |
| 273 | + rhs(offset + i) /= main_diagonal(offset + i); |
| 274 | + } |
| 275 | + }); |
| 276 | + } |
| 277 | + Kokkos::fence(); |
| 278 | + } |
| 279 | + |
| 280 | +private: |
| 281 | + int matrix_dimension_; |
| 282 | + int batch_count_; |
| 283 | + |
| 284 | + Vector<T> main_diagonal_; |
| 285 | + Vector<T> sub_diagonal_; |
| 286 | + Vector<T> buffer_; |
| 287 | + Vector<T> gamma_; |
| 288 | + |
| 289 | + bool is_cyclic_; |
| 290 | + bool is_factorized_; |
| 291 | +}; |
0 commit comments