Skip to content

Commit 42708de

Browse files
authored
Enable the case N != ldc in EigenBlasGemm. (#5976)
* Enable the case N != ldc in EigenBlasGemm. * Use MemoryHandle instead of direct calling of posix_memalign to alloc temporary memory. * Use Eigen's slice() instead of a temporary memory. * Add if-else for different cases in EigenBlasGemm (for N ?= ldc).
1 parent 5f0d081 commit 42708de

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

paddle/function/EigenGemm.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ template <class T>
2121
struct EigenBlasGemm {
2222
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, int>,
2323
Eigen::Aligned>
24-
Matrix;
24+
EigenMatrix;
2525

2626
static void compute(const bool transA,
2727
const bool transB,
@@ -56,14 +56,13 @@ struct EigenBlasGemm {
5656
sizeB[1] = N;
5757
CHECK_EQ(N, ldb);
5858
}
59-
Eigen::array<int, 2> sizeC;
60-
sizeC[0] = M;
61-
sizeC[1] = N;
62-
CHECK_EQ(N, ldc);
59+
Eigen::array<int, 2> sizeC = {{M, ldc}};
60+
Eigen::array<int, 2> offsetC = {{0, 0}};
61+
Eigen::array<int, 2> extentC = {{M, N}};
6362

64-
const Matrix a(const_cast<T*>(A), sizeA);
65-
const Matrix b(const_cast<T*>(B), sizeB);
66-
Matrix c(C, sizeC);
63+
const EigenMatrix a(const_cast<T*>(A), sizeA);
64+
const EigenMatrix b(const_cast<T*>(B), sizeB);
65+
EigenMatrix c(C, sizeC);
6766

6867
typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
6968
Eigen::array<DimPair, 1> dims;
@@ -72,12 +71,23 @@ struct EigenBlasGemm {
7271
dims[0].second = transB ? 1 : 0;
7372

7473
Eigen::DefaultDevice device;
75-
if (alpha == T(1) && beta == T(0)) {
76-
c.device(device) = a.contract(b, dims);
77-
} else if (alpha == T(1) && beta == T(1)) {
78-
c.device(device) += a.contract(b, dims);
74+
if (N == ldc) {
75+
if (alpha == T(1) && beta == T(0)) {
76+
c.device(device) = a.contract(b, dims);
77+
} else if (alpha == T(1) && beta == T(1)) {
78+
c.device(device) += a.contract(b, dims);
79+
} else {
80+
c.device(device) = alpha * a.contract(b, dims) + beta * c;
81+
}
7982
} else {
80-
c.device(device) = alpha * a.contract(b, dims) + beta * c;
83+
if (alpha == T(1) && beta == T(0)) {
84+
c.slice(offsetC, extentC).device(device) = a.contract(b, dims);
85+
} else if (alpha == T(1) && beta == T(1)) {
86+
c.slice(offsetC, extentC).device(device) += a.contract(b, dims);
87+
} else {
88+
c.slice(offsetC, extentC).device(device) =
89+
alpha * a.contract(b, dims) + beta * c.slice(offsetC, extentC);
90+
}
8191
}
8292
}
8393
};

0 commit comments

Comments
 (0)