@@ -64,31 +64,109 @@ namespace alfi::util::linalg {
6464 return X;
6565 }
6666
67+ /* *
68+ @brief Solves a tridiagonal system of linear equations.
69+
70+ This function solves a system of linear equations \f(AX = B\f), where \f(A\f) is a tridiagonal matrix.
71+ The input vectors are modified in place.
72+
73+ Unlike \ref tridiag_solve, this algorithm is generally unstable.\n
74+ The sufficient conditions for this algorithm to be stable are
75+ (see this article: https://stsynkov.math.ncsu.edu/book_sample_material/Sections_5.4-5.5.pdf):
76+ - \f(\abs{diag[0]} \ge \abs{upper[0]}\f),
77+ - \f(\abs{diag[k]} \ge \abs{lower[k]} + \abs{upper[k]}, k = 1, 2, ..., n - 2\f),
78+ - \f(\abs{diag[n-1]} \ge \abs{lower[n-1]}\f),
79+ - one of \f(n\f) inequalities is strict, i.e., \f(>\f) rather than \f(\ge\f).
80+
81+ The first element of `lower` and the last element of `upper` are ignored.
82+
83+ @param lower the subdiagonal elements of the tridiagonal matrix (first element is ignored)
84+ @param diag the diagonal elements of the tridiagonal matrix
85+ @param upper the superdiagonal elements of the tridiagonal matrix (last element is ignored)
86+ @param right the right-hand side vector of the system
87+ @return the solution vector
88+ */
89+ template <typename Number = DefaultNumber, template <typename > class Container = DefaultContainer>
90+ Container<Number> tridiag_solve_unstable (const auto & lower, auto && diag, const auto & upper, auto && right) {
91+ const auto n = right.size ();
92+ assert (n == lower.size ());
93+ assert (n == diag.size ());
94+ assert (n == upper.size ());
95+ if (n == 0 ) {
96+ return {};
97+ }
98+ for (SizeT i = 0 ; i < n - 1 ; ++i) {
99+ const auto m = lower[i+1 ] / diag[i];
100+ diag[i+1 ] -= m * upper[i];
101+ right[i+1 ] -= m * right[i];
102+ }
103+ Container<Number> X (n);
104+ X[n-1 ] = right[n-1 ] / diag[n-1 ];
105+ for (SizeT iter = 2 ; iter <= n; ++iter) {
106+ const auto i = n - iter;
107+ X[i] = (right[i] - upper[i] * X[i+1 ]) / diag[i];
108+ }
109+ return X;
110+ }
111+
112+ /* *
113+ @brief Solves a tridiagonal system of linear equations.
114+
115+ This function solves a system of linear equations \f(AX = B\f), where \f(A\f) is a tridiagonal matrix.
116+ The input vectors are modified in place.
117+
118+ Unlike \ref tridiag_solve_unstable, this algorithm is stable.
119+
120+ The first element of `lower` and the last element of `upper` are ignored.
121+
122+ @note Based on the https://github.com/snsinfu/cxx-spline/blob/625d4d325cb2/include/spline.hpp#L40-L105 code
123+ from https://github.com/snsinfu, which was licensed under the BSL-1.0 (Boost Software License 1.0).
124+
125+ @param lower the subdiagonal elements of the tridiagonal matrix (first element is ignored)
126+ @param diag the diagonal elements of the tridiagonal matrix
127+ @param upper the superdiagonal elements of the tridiagonal matrix (last element is ignored)
128+ @param right the right-hand side vector of the system
129+ @return the solution vector
130+ */
67131 template <typename Number = DefaultNumber, template <typename > class Container = DefaultContainer>
68- Container<Number> tridiag_solve (
69- Number a11, Number a12, Number r1,
70- Number ann1, Number ann, Number rn,
71- auto lower, auto diag, auto upper, auto right,
72- SizeT n
73- ) {
74- const auto m_1 = *lower / a11;
75- *diag -= m_1 * a12;
76- *right -= m_1 * r1;
77- ++lower, ++diag, ++upper, ++right;
78- for (SizeT i = 3 ; i < n; ++i, ++lower, ++diag, ++upper, ++right) {
79- const auto m = *lower / *(diag - 1 );
80- *diag -= m * *(upper - 1 );
81- *right -= m * *(right - 1 );
132+ Container<Number> tridiag_solve (auto && lower, auto && diag, auto && upper, auto && right) {
133+ const auto n = right.size ();
134+ assert (n == lower.size ());
135+ assert (n == diag.size ());
136+ assert (n == upper.size ());
137+ if (n == 0 ) {
138+ return {};
139+ }
140+ if (n == 1 ) {
141+ return {right[0 ]/diag[0 ]};
142+ }
143+ for (SizeT i = 0 ; i < n - 1 ; ++i) {
144+ if (std::abs (diag[i]) >= std::abs (lower[i+1 ])) {
145+ const auto m = lower[i+1 ] / diag[i];
146+ diag[i+1 ] -= m * upper[i];
147+ right[i+1 ] -= m * right[i];
148+ lower[i+1 ] = 0 ;
149+ } else {
150+ // Swap rows i and (i+1).
151+ // Eliminate the lower[i+1] element by reducing row (i+1) by row i.
152+ // Use lower[i+1] as a buffer for the non-tridiagonal element (i,i+2) (hack, used below).
153+ const auto m = diag[i] / lower[i+1 ];
154+ diag[i] = lower[i+1 ];
155+ lower[i+1 ] = upper[i+1 ];
156+ upper[i+1 ] *= -m;
157+ std::swap (upper[i], diag[i+1 ]);
158+ diag[i+1 ] -= m * upper[i];
159+ std::swap (right[i], right[i+1 ]);
160+ right[i+1 ] -= m * right[i];
161+ }
82162 }
83- const auto m_n = ann1 / *(diag - 1 );
84- ann -= m_n * *(upper - 1 );
85- rn -= m_n * *(right - 1 );
86163 Container<Number> X (n);
87- X[n-1 ] = rn / ann;
88- for (auto i = n - 2 ; i > 0 ; --i) {
89- X[i] = (*--right - *--upper * X[i+1 ]) / *--diag;
164+ X[n-1 ] = right[n-1 ] / diag[n-1 ];
165+ X[n-2 ] = (right[n-2 ] - upper[n-2 ] * X[n-1 ]) / diag[n-2 ];
166+ for (SizeT iter = 3 ; iter <= n; ++iter) {
167+ const auto i = n - iter;
168+ X[i] = (right[i] - upper[i] * X[i+1 ] - lower[i+1 ] * X[i+2 ]) / diag[i];
90169 }
91- X[0 ] = (r1 - a12 * X[1 ]) / a11;
92170 return X;
93171 }
94172}
0 commit comments