3434#include < Core/Datatypes/MatrixTypeConversions.h>
3535#include < Eigen/Sparse>
3636
37+ using namespace SCIRun ;
3738using namespace SCIRun ::Core::Algorithms::Math;
3839using namespace SCIRun ::Core::Datatypes;
3940using namespace SCIRun ::Core::Algorithms;
4041using namespace SCIRun ::Core;
4142
4243namespace
4344{
44- template <class ColumnMatrixType >
45+ template <class ColumnMatrixType , template < typename > class SolverType >
4546 class SolveLinearSystemAlgorithmEigenCGImpl
4647 {
4748 public:
48- SolveLinearSystemAlgorithmEigenCGImpl (const ColumnMatrixType& rhs, double tolerance, int maxIterations) :
49+ SolveLinearSystemAlgorithmEigenCGImpl (SharedPointer< ColumnMatrixType> rhs, double tolerance, int maxIterations) :
4950 tolerance_ (tolerance), maxIterations_(maxIterations), rhs_(rhs) {}
5051
5152 using SolutionType = ColumnMatrixType;
5253
5354 template <class MatrixType >
5455 typename ColumnMatrixType::EigenBase solveWithEigen (const MatrixType& lhs)
5556 {
56- Eigen::ConjugateGradient <typename MatrixType::EigenBase> cg ;
57- cg .compute (lhs);
57+ SolverType <typename MatrixType::EigenBase> solver ;
58+ solver .compute (lhs);
5859
59- if (cg .info () != Eigen::Success)
60+ if (solver .info () != Eigen::Success)
6061 BOOST_THROW_EXCEPTION (AlgorithmInputException ()
61- << LinearAlgebraErrorMessage (" Conjugate gradient initialization was unsuccessful" )
62- << EigenComputationInfo (cg .info ()));
63-
64- cg .setTolerance (tolerance_);
65- cg .setMaxIterations (maxIterations_);
66- auto solution = cg .solve (rhs_).eval ();
67- tolerance_ = cg .error ();
68- maxIterations_ = cg .iterations ();
62+ << LinearAlgebraErrorMessage (" Eigen solver initialization was unsuccessful" )
63+ << EigenComputationInfo (solver .info ()));
64+
65+ solver .setTolerance (tolerance_);
66+ solver .setMaxIterations (maxIterations_);
67+ auto solution = solver .solve (* rhs_).eval ();
68+ tolerance_ = solver .error ();
69+ maxIterations_ = solver .iterations ();
6970 return solution;
7071 }
7172
7273 double tolerance_;
7374 int maxIterations_;
7475 private:
75- const ColumnMatrixType& rhs_;
76+ SharedPointer< ColumnMatrixType> rhs_;
7677 };
7778}
7879
@@ -86,6 +87,14 @@ SolveLinearSystemAlgorithm::ComplexOutputs SolveLinearSystemAlgorithm::run(const
8687 return runImpl<ComplexInputs, ComplexOutputs>(input, params);
8788}
8889
90+ template <typename T>
91+ using CG = Eigen::ConjugateGradient<T>;
92+ // Not available yet, need to upgrade Eigen
93+ // template <typename T>
94+ // using LSCG = Eigen::LeastSquaresConjugateGradient<T>;
95+ template <typename T>
96+ using BiCG = Eigen::BiCGSTAB<T>;
97+
8998template <typename In, typename Out>
9099Out SolveLinearSystemAlgorithm::runImpl (const In& input, const Parameters& params) const
91100{
@@ -101,9 +110,32 @@ Out SolveLinearSystemAlgorithm::runImpl(const In& input, const Parameters& param
101110 int maxIterations = std::get<1 >(params);
102111 ENSURE_POSITIVE_INT (maxIterations, " Max iterations out of range!" );
103112
104- using SolutionType = DenseMatrixGeneric<typename std::tuple_element<0 , In>::type::element_type::value_type>;
105- using SolverType = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType>;
106- SolverType impl (*b, tolerance, maxIterations);
113+ auto method = std::get<2 >(params);
114+
115+ using SolutionType = DenseColumnMatrixGeneric<typename std::tuple_element<0 , In>::type::element_type::value_type>;
116+ using AlgoTypeCG = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType, CG>;
117+ using AlgoTypeBiCG = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType, BiCG>;
118+
119+ if (" cg" == method)
120+ return solve<AlgoTypeCG, In, Out>(input, params);
121+ else if (" bicg" == method)
122+ return solve<AlgoTypeBiCG, In, Out>(input, params);
123+ else
124+ {
125+ BOOST_THROW_EXCEPTION (AlgorithmProcessingException () << ErrorMessage (" Need to upgrade Eigen for LSCG." ));
126+ }
127+ }
128+
129+ template <typename SolverType, typename In, typename Out>
130+ Out SolveLinearSystemAlgorithm::solve (const In& input, const Parameters& params) const
131+ {
132+ auto A = std::get<0 >(input);
133+ auto b = std::get<1 >(input);
134+ double tolerance = std::get<0 >(params);
135+ int maxIterations = std::get<1 >(params);
136+
137+ SolverType impl (b, tolerance, maxIterations);
138+
107139 typename SolverType::SolutionType x;
108140 if (matrixIs::dense (A))
109141 {
@@ -120,8 +152,7 @@ Out SolveLinearSystemAlgorithm::runImpl(const In& input, const Parameters& param
120152
121153 if (x.size () != 0 )
122154 {
123- // / @todo: move ctor
124- auto solution (boost::make_shared<SolutionType>(x));
155+ auto solution (boost::make_shared<typename SolverType::SolutionType>(x));
125156 return Out (solution, impl.tolerance_ , impl.maxIterations_ );
126157 }
127158 else
0 commit comments