2727#define _H_GRB_ALGORITHMS_CONJUGATE_GRADIENT
2828
2929#include < cstdio>
30- #include < cmath >
30+ #include < complex >
3131
3232#include < graphblas.hpp>
3333#include < graphblas/utils/iscomplex.hpp>
@@ -144,8 +144,7 @@ namespace grb {
144144 * performance semantics, with the exception of getters such as #grb::nnz, are
145145 * specific to the backend selected during compilation.
146146 */
147- template <
148- Descriptor descr = descriptors::no_operation,
147+ template < Descriptor descr = descriptors::no_operation,
149148 typename IOType,
150149 typename ResidualType,
151150 typename NonzeroType,
@@ -155,20 +154,19 @@ namespace grb {
155154 grb::identities::zero, grb::identities::one
156155 >,
157156 class Minus = operators::subtract< IOType >,
158- class Divide = operators::divide< IOType >,
159- typename RSI, typename NZI, Backend backend
157+ class Divide = operators::divide< IOType >
160158 >
161159 grb::RC conjugate_gradient (
162- grb::Vector< IOType, backend > &x,
163- const grb::Matrix< NonzeroType, backend, RSI, RSI, NZI > &A,
164- const grb::Vector< InputType, backend > &b,
160+ grb::Vector< IOType > &x,
161+ const grb::Matrix< NonzeroType > &A,
162+ const grb::Vector< InputType > &b,
165163 const size_t max_iterations,
166164 ResidualType tol,
167165 size_t &iterations,
168166 ResidualType &residual,
169- grb::Vector< IOType, backend > &r,
170- grb::Vector< IOType, backend > &u,
171- grb::Vector< IOType, backend > &temp,
167+ grb::Vector< IOType > &r,
168+ grb::Vector< IOType > &u,
169+ grb::Vector< IOType > &temp,
172170 const Ring &ring = Ring(),
173171 const Minus &minus = Minus(),
174172 const Divide ÷ = Divide()
@@ -326,7 +324,7 @@ namespace grb {
326324 assert ( ret == SUCCESS );
327325
328326 if ( ret == SUCCESS ) {
329- tol *= std:: sqrt ( grb::utils::is_complex< IOType >::modulus ( bnorm ) );
327+ tol *= sqrt ( grb::utils::is_complex< IOType >::modulus ( bnorm ) );
330328 }
331329
332330 size_t iter = 0 ;
@@ -419,7 +417,7 @@ namespace grb {
419417
420418 // return correct error code
421419 if ( ret == SUCCESS ) {
422- if ( std:: sqrt ( residual ) >= tol ) {
420+ if ( sqrt ( residual ) >= tol ) {
423421 // did not converge within iterations
424422 return FAILED;
425423 }
0 commit comments