Skip to content

Commit 7b5b3fc

Browse files
committed
Closes #1586
1 parent 6f8567c commit 7b5b3fc

File tree

6 files changed

+215
-104
lines changed

6 files changed

+215
-104
lines changed

src/Core/Algorithms/Math/SolveLinearSystemWithEigen.cc

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ using namespace SCIRun::Core;
4242

4343
namespace
4444
{
45-
template <class ColumnMatrixType>
45+
template <class ColumnMatrixType, template <typename> class SolverType>
4646
class SolveLinearSystemAlgorithmEigenCGImpl
4747
{
4848
public:
@@ -54,19 +54,19 @@ namespace
5454
template <class MatrixType>
5555
typename ColumnMatrixType::EigenBase solveWithEigen(const MatrixType& lhs)
5656
{
57-
Eigen::ConjugateGradient<typename MatrixType::EigenBase> cg;
58-
cg.compute(lhs);
57+
SolverType<typename MatrixType::EigenBase> solver;
58+
solver.compute(lhs);
5959

60-
if (cg.info() != Eigen::Success)
60+
if (solver.info() != Eigen::Success)
6161
BOOST_THROW_EXCEPTION(AlgorithmInputException()
62-
<< LinearAlgebraErrorMessage("Conjugate gradient initialization was unsuccessful")
63-
<< EigenComputationInfo(cg.info()));
64-
65-
cg.setTolerance(tolerance_);
66-
cg.setMaxIterations(maxIterations_);
67-
auto solution = cg.solve(*rhs_).eval();
68-
tolerance_ = cg.error();
69-
maxIterations_ = cg.iterations();
62+
<< LinearAlgebraErrorMessage("Eigen solver initialization was unsuccessful")
63+
<< EigenComputationInfo(solver.info()));
64+
65+
solver.setTolerance(tolerance_);
66+
solver.setMaxIterations(maxIterations_);
67+
auto solution = solver.solve(*rhs_).eval();
68+
tolerance_ = solver.error();
69+
maxIterations_ = solver.iterations();
7070
return solution;
7171
}
7272

@@ -87,6 +87,14 @@ SolveLinearSystemAlgorithm::ComplexOutputs SolveLinearSystemAlgorithm::run(const
8787
return runImpl<ComplexInputs, ComplexOutputs>(input, params);
8888
}
8989

90+
template <typename T>
91+
using CG = Eigen::ConjugateGradient<T>;
92+
// Not available yet, need to upgrade Eigen
93+
// template <typename T>
94+
// using LSCG = Eigen::LeastSquaresConjugateGradient<T>;
95+
template <typename T>
96+
using BiCG = Eigen::BiCGSTAB<T>;
97+
9098
template <typename In, typename Out>
9199
Out SolveLinearSystemAlgorithm::runImpl(const In& input, const Parameters& params) const
92100
{
@@ -102,8 +110,29 @@ Out SolveLinearSystemAlgorithm::runImpl(const In& input, const Parameters& param
102110
int maxIterations = std::get<1>(params);
103111
ENSURE_POSITIVE_INT(maxIterations, "Max iterations out of range!");
104112

113+
auto method = std::get<2>(params);
114+
105115
using SolutionType = DenseColumnMatrixGeneric<typename std::tuple_element<0, In>::type::element_type::value_type>;
106-
using SolverType = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType>;
116+
using AlgoTypeCG = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType, CG>;
117+
using AlgoTypeBiCG = SolveLinearSystemAlgorithmEigenCGImpl<SolutionType, BiCG>;
118+
119+
if ("cg" == method)
120+
return solve<AlgoTypeCG, In, Out>(input, params);
121+
else if ("bicg" == method)
122+
return solve<AlgoTypeBiCG, In, Out>(input, params);
123+
else
124+
{
125+
BOOST_THROW_EXCEPTION(AlgorithmProcessingException() << ErrorMessage("Need to upgrade Eigen for LSCG."));
126+
}
127+
}
128+
129+
template <typename SolverType, typename In, typename Out>
130+
Out SolveLinearSystemAlgorithm::solve(const In& input, const Parameters& params) const
131+
{
132+
auto A = std::get<0>(input);
133+
auto b = std::get<1>(input);
134+
double tolerance = std::get<0>(params);
135+
int maxIterations = std::get<1>(params);
107136

108137
SolverType impl(b, tolerance, maxIterations);
109138

@@ -123,8 +152,7 @@ Out SolveLinearSystemAlgorithm::runImpl(const In& input, const Parameters& param
123152

124153
if (x.size() != 0)
125154
{
126-
/// @todo: move ctor
127-
auto solution(boost::make_shared<SolutionType>(x));
155+
auto solution(boost::make_shared<typename SolverType::SolutionType>(x));
128156
return Out(solution, impl.tolerance_, impl.maxIterations_);
129157
}
130158
else

src/Core/Algorithms/Math/SolveLinearSystemWithEigen.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ namespace Math {
4545
public:
4646
typedef std::tuple<SCIRun::Core::Datatypes::MatrixHandle, SCIRun::Core::Datatypes::DenseColumnMatrixHandle> Inputs;
4747
typedef std::tuple<SCIRun::Core::Datatypes::ComplexMatrixHandle, SCIRun::Core::Datatypes::ComplexDenseColumnMatrixHandle> ComplexInputs;
48-
typedef std::tuple<double, int> Parameters;
48+
typedef std::tuple<double, int, std::string> Parameters;
4949
typedef std::tuple<SCIRun::Core::Datatypes::DenseColumnMatrixHandle, double, int> Outputs;
5050
typedef std::tuple<SCIRun::Core::Datatypes::ComplexDenseColumnMatrixHandle, double, int> ComplexOutputs;
5151

@@ -56,6 +56,8 @@ namespace Math {
5656
private:
5757
template <typename In, typename Out>
5858
Out runImpl(const In& input, const Parameters& params) const;
59+
template <typename SolverType, typename In, typename Out>
60+
Out solve(const In& input, const Parameters& params) const;
5961
};
6062

6163
typedef boost::error_info<struct tag_eigen_computation, Eigen::ComputationInfo> EigenComputationInfo;

src/Core/Algorithms/Math/Tests/SolveLinearSystemWithEigenTests.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ TEST(EigenSparseSolverTest, DISABLED_CanSolveBigSystem)
344344
ScopedTimer t("using algorithm object");
345345
SolveLinearSystemAlgorithm algo;
346346

347-
x = algo.run(std::make_tuple(A, bCol), std::make_tuple(1e-20, 4000));
347+
x = algo.run(std::make_tuple(A, bCol), std::make_tuple(1e-20, 4000, "cg"));
348348
MatrixHandle solution = std::get<0>(x);
349349

350350
ASSERT_TRUE(solution.get() != nullptr);

0 commit comments

Comments
 (0)