Skip to content

Commit f7f365a

Browse files
fwyzard authored and smuzaffar committed
Extend support for self-adjoint matrices on CUDA devices
1 parent 017ce22 commit f7f365a

File tree

7 files changed: +20 additions, -8 deletions

Eigen/src/Core/ProductEvaluators.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -813,6 +813,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag>
813813
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
814814

815815
template<typename Dest>
816+
EIGEN_DEVICE_FUNC
816817
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
817818
{
818819
selfadjoint_product_impl<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>::run(dst, lhs, rhs.nestedExpression(), alpha);

Eigen/src/Core/TriangularMatrix.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -961,6 +961,8 @@ template< typename DstXprType, typename Lhs, typename Rhs, typename Scalar>
961961
struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::assign_op<Scalar,typename Product<Lhs,Rhs,DefaultProduct>::Scalar>, Dense2Triangular>
962962
{
963963
typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType;
964+
965+
EIGEN_DEVICE_FUNC
964966
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,typename SrcXprType::Scalar> &)
965967
{
966968
Index dstRows = src.rows();

Eigen/src/Core/products/GeneralBlockPanelKernel.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2361,11 +2361,11 @@ template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pa
23612361
struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
23622362
{
23632363
typedef typename DataMapper::LinearMapper LinearMapper;
2364-
EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2364+
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
23652365
};
23662366

23672367
template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2368-
EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2368+
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
23692369
::operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
23702370
{
23712371
typedef typename unpacket_traits<Packet>::half HalfPacket;
@@ -2624,8 +2624,8 @@ struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMo
26242624
typedef typename DataMapper::LinearMapper LinearMapper;
26252625
enum { PacketSize = packet_traits<Scalar>::size,
26262626
HalfPacketSize = unpacket_traits<HalfPacket>::size,
2627-
QuarterPacketSize = unpacket_traits<QuarterPacket>::size};
2628-
EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0)
2627+
QuarterPacketSize = unpacket_traits<QuarterPacket>::size};
2628+
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0)
26292629
{
26302630
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
26312631
EIGEN_UNUSED_VARIABLE(stride);

Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,8 @@ template <typename Index, typename LhsScalar, int LhsStorageOrder, bool Conjugat
6060
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,UpLo,Version>
6161
{
6262
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
63+
64+
EIGEN_DEVICE_FUNC
6365
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride,
6466
const RhsScalar* _rhs, Index rhsStride,
6567
ResScalar* _res, Index resIncr, Index resStride,
@@ -144,6 +146,7 @@ struct tribb_kernel
144146
enum {
145147
BlockSize = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret
146148
};
149+
EIGEN_DEVICE_FUNC
147150
void operator()(ResScalar* _res, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
148151
{
149152
typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
@@ -252,6 +255,7 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true>
252255
template<typename MatrixType, typename ProductType, int UpLo>
253256
struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
254257
{
258+
EIGEN_DEVICE_FUNC
255259
static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta)
256260
{
257261
typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;

Eigen/src/Core/products/SelfadjointMatrixMatrix.h

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,8 @@ template<typename Scalar, typename Index, int nr, int StorageOrder>
101101
struct symm_pack_rhs
102102
{
103103
enum { PacketSize = packet_traits<Scalar>::size };
104+
105+
EIGEN_DEVICE_FUNC
104106
void operator()(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
105107
{
106108
Index end_k = k2 + rows;
@@ -423,6 +425,7 @@ template <typename Scalar, typename Index,
423425
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>
424426
{
425427

428+
EIGEN_DEVICE_FUNC
426429
static EIGEN_DONT_INLINE void run(
427430
Index rows, Index cols,
428431
const Scalar* _lhs, Index lhsStride,
@@ -435,7 +438,7 @@ template <typename Scalar, typename Index,
435438
int LhsStorageOrder, bool ConjugateLhs,
436439
int RhsStorageOrder, bool ConjugateRhs,
437440
int ResInnerStride>
438-
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run(
441+
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run(
439442
Index rows, Index cols,
440443
const Scalar* _lhs, Index lhsStride,
441444
const Scalar* _rhs, Index rhsStride,
@@ -505,6 +508,7 @@ struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false>
505508
};
506509

507510
template<typename Dest>
511+
EIGEN_DEVICE_FUNC
508512
static void run(Dest &dst, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
509513
{
510514
eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());

Eigen/src/Core/products/SelfadjointMatrixVector.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ struct selfadjoint_product_impl<Lhs,0,true,Rhs,RhsMode,false>
246246
enum { RhsUpLo = RhsMode&(Upper|Lower) };
247247

248248
template<typename Dest>
249-
static void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
249+
static EIGEN_DEVICE_FUNC
250+
void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
250251
{
251252
// let's simply transpose the product
252253
Transpose<Dest> destT(dest);

Eigen/src/Core/util/BlasUtil.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -594,8 +594,8 @@ struct blas_traits<Transpose<NestedXpr> >
594594
enum {
595595
IsTransposed = Base::IsTransposed ? 0 : 1
596596
};
597-
static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
598-
static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
597+
EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
598+
EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
599599
};
600600

601601
template<typename T>

0 commit comments

Comments
 (0)