Skip to content

Commit 000e000

Browse files
fwyzard authored and smuzaffar committed
Extend support for matrix inversion on CUDA devices
Extend support for matrix inversion on CUDA devices to matrices larger than 4x4.

The size of the matrices that can be inverted is limited at runtime by the per-thread stack size. For example (building with NVCC and CUDA 9.2):

- 9 KB: Matrix<double, Dynamic, Dynamic, 0, 9, 9>
- 10 KB: Matrix<double, Dynamic, Dynamic, 0, 12, 12>
- 11 KB: Matrix<double, Dynamic, Dynamic, 0, 13, 13>
- 12 KB: Matrix<double, Dynamic, Dynamic, 0, 15, 15>
- 13 KB: Matrix<double, Dynamic, Dynamic, 0, 16, 16>
- 14 KB: Matrix<double, Dynamic, Dynamic, 0, 16, 16>
- 15 KB: Matrix<double, Dynamic, Dynamic, 0, 17, 17>
- 16 KB: Matrix<double, Dynamic, Dynamic, 0, 18, 18>
- 17 KB: Matrix<double, Dynamic, Dynamic, 0, 19, 19>
- 18 KB: Matrix<double, Dynamic, Dynamic, 0, 20, 20>
- 19 KB: Matrix<double, Dynamic, Dynamic, 0, 21, 21>
- 20 KB: Matrix<double, Dynamic, Dynamic, 0, 21, 21>
- 21 KB: Matrix<double, Dynamic, Dynamic, 0, 22, 22>
- 22 KB: Matrix<double, Dynamic, Dynamic, 0, 23, 23>
- 23 KB: Matrix<double, Dynamic, Dynamic, 0, 23, 23>
- 24 KB: Matrix<double, Dynamic, Dynamic, 0, 24, 24>
- etc.

For dynamic matrices the necessary stack size depends on the maximum matrix size, and is slightly higher than for a fixed-size matrix of the same size.
1 parent 03770df commit 000e000

File tree

10 files changed

+129
-13
lines changed

10 files changed

+129
-13
lines changed

Eigen/src/Core/MatrixBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,8 @@ template<typename Derived> class MatrixBase
324324
/////////// LU module ///////////
325325

326326
inline const FullPivLU<PlainObject> fullPivLu() const;
327+
328+
EIGEN_DEVICE_FUNC
327329
inline const PartialPivLU<PlainObject> partialPivLu() const;
328330

329331
inline const PartialPivLU<PlainObject> lu() const;

Eigen/src/Core/PermutationMatrix.h

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class PermutationBase : public EigenBase<Derived>
7171

7272
/** Copies the other permutation into *this */
7373
template<typename OtherDerived>
74+
EIGEN_DEVICE_FUNC
7475
Derived& operator=(const PermutationBase<OtherDerived>& other)
7576
{
7677
indices() = other.indices();
@@ -79,6 +80,7 @@ class PermutationBase : public EigenBase<Derived>
7980

8081
/** Assignment from the Transpositions \a tr */
8182
template<typename OtherDerived>
83+
EIGEN_DEVICE_FUNC
8284
Derived& operator=(const TranspositionsBase<OtherDerived>& tr)
8385
{
8486
setIdentity(tr.size());
@@ -87,6 +89,18 @@ class PermutationBase : public EigenBase<Derived>
8789
return derived();
8890
}
8991

92+
#ifndef EIGEN_PARSED_BY_DOXYGEN
93+
/** This is a special case of the templated operator=. Its purpose is to
94+
* prevent a default operator= from hiding the templated operator=.
95+
*/
96+
EIGEN_DEVICE_FUNC
97+
Derived& operator=(const PermutationBase& other)
98+
{
99+
indices() = other.indices();
100+
return derived();
101+
}
102+
#endif
103+
90104
/** \returns the number of rows */
91105
inline EIGEN_DEVICE_FUNC Index rows() const { return Index(indices().size()); }
92106

@@ -116,18 +130,22 @@ class PermutationBase : public EigenBase<Derived>
116130
}
117131

118132
/** const version of indices(). */
133+
EIGEN_DEVICE_FUNC
119134
const IndicesType& indices() const { return derived().indices(); }
120135
/** \returns a reference to the stored array representing the permutation. */
136+
EIGEN_DEVICE_FUNC
121137
IndicesType& indices() { return derived().indices(); }
122138

123139
/** Resizes to given size.
124140
*/
141+
EIGEN_DEVICE_FUNC
125142
inline void resize(Index newSize)
126143
{
127144
indices().resize(newSize);
128145
}
129146

130147
/** Sets *this to be the identity permutation matrix */
148+
EIGEN_DEVICE_FUNC
131149
void setIdentity()
132150
{
133151
StorageIndex n = StorageIndex(size());
@@ -137,6 +155,7 @@ class PermutationBase : public EigenBase<Derived>
137155

138156
/** Sets *this to be the identity permutation matrix of given size.
139157
*/
158+
EIGEN_DEVICE_FUNC
140159
void setIdentity(Index newSize)
141160
{
142161
resize(newSize);
@@ -171,10 +190,11 @@ class PermutationBase : public EigenBase<Derived>
171190
*
172191
* \sa applyTranspositionOnTheLeft(Index,Index)
173192
*/
193+
EIGEN_DEVICE_FUNC
174194
Derived& applyTranspositionOnTheRight(Index i, Index j)
175195
{
176196
eigen_assert(i>=0 && j>=0 && i<size() && j<size());
177-
std::swap(indices().coeffRef(i), indices().coeffRef(j));
197+
numext::swap(indices().coeffRef(i), indices().coeffRef(j));
178198
return derived();
179199
}
180200

@@ -307,21 +327,31 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
307327
typedef typename Traits::StorageIndex StorageIndex;
308328
#endif
309329

330+
EIGEN_DEVICE_FUNC
310331
inline PermutationMatrix()
311332
{}
312333

313334
/** Constructs an uninitialized permutation matrix of given size.
314335
*/
336+
EIGEN_DEVICE_FUNC
315337
explicit inline PermutationMatrix(Index size) : m_indices(size)
316338
{
317339
eigen_internal_assert(size <= NumTraits<StorageIndex>::highest());
318340
}
319341

320342
/** Copy constructor. */
321343
template<typename OtherDerived>
344+
EIGEN_DEVICE_FUNC
322345
inline PermutationMatrix(const PermutationBase<OtherDerived>& other)
323346
: m_indices(other.indices()) {}
324347

348+
#ifndef EIGEN_PARSED_BY_DOXYGEN
349+
/** Standard copy constructor. Defined only to prevent a default copy constructor
350+
* from hiding the other templated constructor */
351+
EIGEN_DEVICE_FUNC
352+
inline PermutationMatrix(const PermutationMatrix& other) : m_indices(other.indices()) {}
353+
#endif
354+
325355
/** Generic constructor from expression of the indices. The indices
326356
* array has the meaning that the permutations sends each integer i to indices[i].
327357
*
@@ -330,11 +360,13 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
330360
* array's size.
331361
*/
332362
template<typename Other>
363+
EIGEN_DEVICE_FUNC
333364
explicit inline PermutationMatrix(const MatrixBase<Other>& indices) : m_indices(indices)
334365
{}
335366

336367
/** Convert the Transpositions \a tr to a permutation matrix */
337368
template<typename Other>
369+
EIGEN_DEVICE_FUNC
338370
explicit PermutationMatrix(const TranspositionsBase<Other>& tr)
339371
: m_indices(tr.size())
340372
{
@@ -343,6 +375,7 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
343375

344376
/** Copies the other permutation into *this */
345377
template<typename Other>
378+
EIGEN_DEVICE_FUNC
346379
PermutationMatrix& operator=(const PermutationBase<Other>& other)
347380
{
348381
m_indices = other.indices();
@@ -351,17 +384,32 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
351384

352385
/** Assignment from the Transpositions \a tr */
353386
template<typename Other>
387+
EIGEN_DEVICE_FUNC
354388
PermutationMatrix& operator=(const TranspositionsBase<Other>& tr)
355389
{
356390
return Base::operator=(tr.derived());
357391
}
358392

393+
#ifndef EIGEN_PARSED_BY_DOXYGEN
394+
/** This is a special case of the templated operator=. Its purpose is to
395+
* prevent a default operator= from hiding the templated operator=.
396+
*/
397+
EIGEN_DEVICE_FUNC
398+
PermutationMatrix& operator=(const PermutationMatrix& other)
399+
{
400+
m_indices = other.m_indices;
401+
return *this;
402+
}
403+
#endif
404+
359405
/** const version of indices(). */
406+
EIGEN_DEVICE_FUNC
360407
const IndicesType& indices() const { return m_indices; }
408+
361409
/** \returns a reference to the stored array representing the permutation. */
410+
EIGEN_DEVICE_FUNC
362411
IndicesType& indices() { return m_indices; }
363412

364-
365413
/**** multiplication helpers to hopefully get RVO ****/
366414

367415
#ifndef EIGEN_PARSED_BY_DOXYGEN
@@ -374,7 +422,9 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
374422
for (StorageIndex i=0; i<end;++i)
375423
m_indices.coeffRef(other.derived().nestedExpression().indices().coeff(i)) = i;
376424
}
425+
377426
template<typename Lhs,typename Rhs>
427+
EIGEN_DEVICE_FUNC
378428
PermutationMatrix(internal::PermPermProduct_t, const Lhs& lhs, const Rhs& rhs)
379429
: m_indices(lhs.indices().size())
380430
{

Eigen/src/Core/Solve.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ template<typename DstXprType, typename DecType, typename RhsType, typename Scala
137137
struct Assignment<DstXprType, Solve<DecType,RhsType>, internal::assign_op<Scalar,Scalar>, Dense2Dense>
138138
{
139139
typedef Solve<DecType,RhsType> SrcXprType;
140+
EIGEN_DEVICE_FUNC
140141
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
141142
{
142143
Index dstRows = src.rows();

Eigen/src/Core/SolverBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,11 @@ class SolverBase : public EigenBase<Derived>
9191
};
9292

9393
/** Default constructor */
94+
EIGEN_DEVICE_FUNC
9495
SolverBase()
9596
{}
9697

98+
EIGEN_DEVICE_FUNC
9799
~SolverBase()
98100
{}
99101

Eigen/src/Core/Transpositions.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,19 @@ class TranspositionsBase
5050
EIGEN_DEVICE_FUNC
5151
inline const StorageIndex& coeff(Index i) const { return indices().coeff(i); }
5252
/** Direct access to the underlying index vector */
53+
EIGEN_DEVICE_FUNC
5354
inline StorageIndex& coeffRef(Index i) { return indices().coeffRef(i); }
5455
/** Direct access to the underlying index vector */
56+
EIGEN_DEVICE_FUNC
5557
inline const StorageIndex& operator()(Index i) const { return indices()(i); }
5658
/** Direct access to the underlying index vector */
59+
EIGEN_DEVICE_FUNC
5760
inline StorageIndex& operator()(Index i) { return indices()(i); }
5861
/** Direct access to the underlying index vector */
62+
EIGEN_DEVICE_FUNC
5963
inline const StorageIndex& operator[](Index i) const { return indices()(i); }
6064
/** Direct access to the underlying index vector */
65+
EIGEN_DEVICE_FUNC
6166
inline StorageIndex& operator[](Index i) { return indices()(i); }
6267

6368
/** const version of indices(). */
@@ -68,12 +73,14 @@ class TranspositionsBase
6873
IndicesType& indices() { return derived().indices(); }
6974

7075
/** Resizes to given size. */
76+
EIGEN_DEVICE_FUNC
7177
inline void resize(Index newSize)
7278
{
7379
indices().resize(newSize);
7480
}
7581

7682
/** Sets \c *this to represents an identity transformation */
83+
EIGEN_DEVICE_FUNC
7784
void setIdentity()
7885
{
7986
for(StorageIndex i = 0; i < indices().size(); ++i)
@@ -161,33 +168,51 @@ class Transpositions : public TranspositionsBase<Transpositions<SizeAtCompileTim
161168
typedef typename Traits::IndicesType IndicesType;
162169
typedef typename IndicesType::Scalar StorageIndex;
163170

171+
EIGEN_DEVICE_FUNC
164172
inline Transpositions() {}
165173

166174
/** Copy constructor. */
167175
template<typename OtherDerived>
176+
EIGEN_DEVICE_FUNC
168177
inline Transpositions(const TranspositionsBase<OtherDerived>& other)
169178
: m_indices(other.indices()) {}
170179

171180
/** Generic constructor from expression of the transposition indices. */
172181
template<typename Other>
182+
EIGEN_DEVICE_FUNC
173183
explicit inline Transpositions(const MatrixBase<Other>& indices) : m_indices(indices)
174184
{}
175185

176186
/** Copies the \a other transpositions into \c *this */
177187
template<typename OtherDerived>
188+
EIGEN_DEVICE_FUNC
178189
Transpositions& operator=(const TranspositionsBase<OtherDerived>& other)
179190
{
180191
return Base::operator=(other);
181192
}
182193

194+
#ifndef EIGEN_PARSED_BY_DOXYGEN
195+
/** This is a special case of the templated operator=. Its purpose is to
196+
* prevent a default operator= from hiding the templated operator=.
197+
*/
198+
EIGEN_DEVICE_FUNC
199+
Transpositions& operator=(const Transpositions& other)
200+
{
201+
m_indices = other.m_indices;
202+
return *this;
203+
}
204+
#endif
205+
183206
/** Constructs an uninitialized permutation matrix of given size.
184207
*/
208+
EIGEN_DEVICE_FUNC
185209
inline Transpositions(Index size) : m_indices(size)
186210
{}
187211

188212
/** const version of indices(). */
189213
EIGEN_DEVICE_FUNC
190214
const IndicesType& indices() const { return m_indices; }
215+
191216
/** \returns a reference to the stored array representing the transpositions. */
192217
EIGEN_DEVICE_FUNC
193218
IndicesType& indices() { return m_indices; }

Eigen/src/Core/products/GeneralMatrixMatrix.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,16 +263,17 @@ class level3_blocking
263263

264264
public:
265265

266+
EIGEN_DEVICE_FUNC
266267
level3_blocking()
267268
: m_blockA(0), m_blockB(0), m_mc(0), m_nc(0), m_kc(0)
268269
{}
269270

270-
inline Index mc() const { return m_mc; }
271-
inline Index nc() const { return m_nc; }
272-
inline Index kc() const { return m_kc; }
271+
EIGEN_DEVICE_FUNC inline Index mc() const { return m_mc; }
272+
EIGEN_DEVICE_FUNC inline Index nc() const { return m_nc; }
273+
EIGEN_DEVICE_FUNC inline Index kc() const { return m_kc; }
273274

274-
inline LhsScalar* blockA() { return m_blockA; }
275-
inline RhsScalar* blockB() { return m_blockB; }
275+
EIGEN_DEVICE_FUNC inline LhsScalar* blockA() { return m_blockA; }
276+
EIGEN_DEVICE_FUNC inline RhsScalar* blockB() { return m_blockB; }
276277
};
277278

278279
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
@@ -304,6 +305,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
304305

305306
public:
306307

308+
EIGEN_DEVICE_FUNC
307309
gemm_blocking_space(Index /*rows*/, Index /*cols*/, Index /*depth*/, Index /*num_threads*/, bool /*full_rows = false*/)
308310
{
309311
this->m_mc = ActualRows;
@@ -344,6 +346,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
344346

345347
public:
346348

349+
EIGEN_DEVICE_FUNC
347350
gemm_blocking_space(Index rows, Index cols, Index depth, Index num_threads, bool l3_blocking)
348351
{
349352
this->m_mc = Transpose ? cols : rows;
@@ -395,6 +398,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
395398
allocateB();
396399
}
397400

401+
EIGEN_DEVICE_FUNC
398402
~gemm_blocking_space()
399403
{
400404
aligned_delete(this->m_blockA, m_sizeA);

Eigen/src/Core/products/TriangularSolverMatrix.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,
3838
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
3939
struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
4040
{
41-
static EIGEN_DONT_INLINE void run(
41+
static EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void run(
4242
Index size, Index otherSize,
4343
const Scalar* _tri, Index triStride,
4444
Scalar* _other, Index otherIncr, Index otherStride,
4545
level3_blocking<Scalar,Scalar>& blocking);
4646
};
4747
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
48-
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
48+
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
4949
Index size, Index otherSize,
5050
const Scalar* _tri, Index triStride,
5151
Scalar* _other, Index otherIncr, Index otherStride,
@@ -61,7 +61,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
6161
typedef gebp_traits<Scalar,Scalar> Traits;
6262

6363
enum {
64-
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
64+
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
6565
IsLower = (Mode&Lower) == Lower
6666
};
6767

Eigen/src/Core/util/BlasUtil.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,11 @@ template<> struct conj_if<true> {
5151

5252
template<> struct conj_if<false> {
5353
template<typename T>
54+
EIGEN_DEVICE_FUNC
5455
inline const T& operator()(const T& x) const { return x; }
56+
5557
template<typename T>
58+
EIGEN_DEVICE_FUNC
5659
inline const T& pconj(const T& x) const { return x; }
5760
};
5861

@@ -472,9 +475,9 @@ class blas_data_mapper
472475
template<typename Scalar, typename Index, int StorageOrder>
473476
class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
474477
public:
475-
EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
478+
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
476479

477-
EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
480+
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
478481
return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
479482
}
480483
};

0 commit comments

Comments
 (0)