@@ -42,7 +42,7 @@ using namespace SCIRun::Core;
4242
4343namespace
4444{
45- template <class ColumnMatrixType >
45+ template <class ColumnMatrixType , template < typename > class SolverType >
4646 class SolveLinearSystemAlgorithmEigenCGImpl
4747 {
4848 public:
@@ -54,19 +54,19 @@ namespace
5454 template <class MatrixType >
5555 typename ColumnMatrixType::EigenBase solveWithEigen (const MatrixType& lhs)
5656 {
57- Eigen::ConjugateGradient <typename MatrixType::EigenBase> cg ;
58- cg .compute (lhs);
57+ SolverType <typename MatrixType::EigenBase> solver ;
58+ solver .compute (lhs);
5959
60- if (cg .info () != Eigen::Success)
60+ if (solver .info () != Eigen::Success)
6161 BOOST_THROW_EXCEPTION (AlgorithmInputException ()
62- << LinearAlgebraErrorMessage (" Conjugate gradient initialization was unsuccessful" )
63- << EigenComputationInfo (cg .info ()));
64-
65- cg .setTolerance (tolerance_);
66- cg .setMaxIterations (maxIterations_);
67- auto solution = cg .solve (*rhs_).eval ();
68- tolerance_ = cg .error ();
69- maxIterations_ = cg .iterations ();
62+ << LinearAlgebraErrorMessage (" Eigen solver initialization was unsuccessful" )
63+ << EigenComputationInfo (solver .info ()));
64+
65+ solver .setTolerance (tolerance_);
66+ solver .setMaxIterations (maxIterations_);
67+ auto solution = solver .solve (*rhs_).eval ();
68+ tolerance_ = solver .error ();
69+ maxIterations_ = solver .iterations ();
7070 return solution;
7171 }
7272
@@ -87,6 +87,14 @@ SolveLinearSystemAlgorithm::ComplexOutputs SolveLinearSystemAlgorithm::run(const
8787 return runImpl<ComplexInputs, ComplexOutputs>(input, params);
8888}
8989
90+ template <typename T>
91+ using CG = Eigen::ConjugateGradient<T>;
92+ // Not available yet, need to upgrade Eigen
93+ // template <typename T>
94+ // using LSCG = Eigen::LeastSquaresConjugateGradient<T>;
95+ template <typename T>
96+ using BiCG = Eigen::BiCGSTAB<T>;
97+
9098template <typename In, typename Out>
9199Out SolveLinearSystemAlgorithm::runImpl (const In& input, const Parameters& params) const
92100{
@@ -102,8 +110,29 @@ Out SolveLinearSystemAlgorithm::runImpl(const In& input, const Parameters& param
102110 int maxIterations = std::get<1 >(params);
103111 ENSURE_POSITIVE_INT (maxIterations, " Max iterations out of range!" );
104112
113+ auto method = std::get<2 >(params);
114+
105115 using SolutionType = DenseColumnMatrixGeneric<typename std::tuple_element<0 , In>::type::element_type::value_type>;
106- using SolverType = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType>;
116+ using AlgoTypeCG = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType, CG>;
117+ using AlgoTypeBiCG = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType, BiCG>;
118+
119+ if (" cg" == method)
120+ return solve<AlgoTypeCG, In, Out>(input, params);
121+ else if (" bicg" == method)
122+ return solve<AlgoTypeBiCG, In, Out>(input, params);
123+ else
124+ {
125+ BOOST_THROW_EXCEPTION (AlgorithmProcessingException () << ErrorMessage (" Need to upgrade Eigen for LSCG." ));
126+ }
127+ }
128+
129+ template <typename SolverType, typename In, typename Out>
130+ Out SolveLinearSystemAlgorithm::solve (const In& input, const Parameters& params) const
131+ {
132+ auto A = std::get<0 >(input);
133+ auto b = std::get<1 >(input);
134+ double tolerance = std::get<0 >(params);
135+ int maxIterations = std::get<1 >(params);
107136
108137 SolverType impl (b, tolerance, maxIterations);
109138
@@ -123,8 +152,7 @@ Out SolveLinearSystemAlgorithm::runImpl(const In& input, const Parameters& param
123152
124153 if (x.size () != 0 )
125154 {
126- // / @todo: move ctor
127- auto solution (boost::make_shared<SolutionType>(x));
155+ auto solution (boost::make_shared<typename SolverType::SolutionType>(x));
128156 return Out (solution, impl.tolerance_ , impl.maxIterations_ );
129157 }
130158 else
0 commit comments