Skip to content

Commit 64a8e6d

Browse files
committed
refine the threshold functions
1 parent 32822b2 commit 64a8e6d

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

paddle/fluid/operators/math/blas_impl.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#pragma once
15+
#include <cmath>
#include <limits>
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
1718

@@ -161,6 +162,25 @@ struct CBlas<platform::float16> {
161162
}
162163
#endif
163164
};
165+
// Decides whether a GEMM call may be dispatched to the LIBXSMM small-matrix
// kernel instead of the regular BLAS GEMM.
//
// Returns true only when all of the following hold:
//   * the build defines PADDLE_WITH_LIBXSMM,
//   * m * n * k does not exceed the (custom) small-matrix threshold,
//   * neither operand is transposed,
//   * alpha is 1 and beta is 0, each within machine epsilon of T.
// In every other case (including builds without LIBXSMM) it returns false.
//
// m, n, k        - GEMM problem dimensions.
// transa, transb - true when the corresponding operand is transposed.
// alpha, beta    - GEMM scaling factors.
template <typename T>
inline static bool UseXSMM(const int &m, const int &n, const int &k,
                           bool transa, bool transb, const T &alpha,
                           const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
  // But the threshold is custom
  constexpr long long LIBXSMM_THRESHOLD = 20 * 20 * 20;
  const T eps = std::numeric_limits<T>::epsilon();

  // Widen before multiplying: in plain int the product of three dimensions
  // overflows for large shapes (e.g. 2048^3 wraps to 0) and such a shape
  // would then be mistaken for a small matrix.
  const long long volume = static_cast<long long>(m) * n * k;

  // The absolute value must be taken of the difference itself, and only then
  // compared against epsilon; otherwise alpha values below 1 are treated as
  // if they were exactly 1.
  const bool alpha_is_one = !(std::abs(alpha - static_cast<T>(1)) > eps);
  const bool beta_is_zero = !(std::abs(beta) > eps);

  return volume <= LIBXSMM_THRESHOLD && !transa && !transb && alpha_is_one &&
         beta_is_zero;
#else
  return false;
#endif
}
164184

165185
template <>
166186
template <typename T>
@@ -172,8 +192,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
172192
int ldb = (transB == CblasNoTrans) ? N : K;
173193
int ldc = N;
174194
#ifdef PADDLE_WITH_LIBXSMM
175-
if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans &&
176-
transB == CblasNoTrans) {
195+
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
196+
beta)) {
177197
// refer to https://github.com/hfp/libxsmm/blob/master/README.md
178198
// Note: SMM use ColMajor
179199
const char transa = 'N';

0 commit comments

Comments
 (0)