Skip to content

Commit f0cd493

Browse files
authored
Merge pull request #11989 from tensor-tang/feature/libxsmm
introduce libxsmm
2 parents 3a769b9 + 2f7b093 commit f0cd493

File tree

6 files changed

+186
-3
lines changed

6 files changed

+186
-3
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ endif()
146146
########################################################################################
147147

148148
include(external/mklml) # download mklml package
149+
include(external/libxsmm) # download, build, install libxsmm
149150
include(external/zlib) # download, build, install zlib
150151
include(external/gflags) # download, build, install gflags
151152
include(external/glog) # download, build, install glog
@@ -232,6 +233,10 @@ if(WITH_MKLML)
232233
list(APPEND EXTERNAL_LIBS ${MKLML_IOMP_LIB})
233234
endif()
234235

236+
if(WITH_LIBXSMM)
237+
list(APPEND EXTERNAL_LIBS ${LIBXSMM_LIBS})
238+
endif()
239+
235240
if(WITH_MKLDNN)
236241
list(APPEND EXTERNAL_LIBS ${MKLDNN_LIB})
237242
endif()

cmake/external/libxsmm.cmake

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
option(WITH_LIBXSMM "Compile with libxsmm" OFF)

if(NOT WITH_LIBXSMM)
  return()
endif()

if(WIN32 OR APPLE OR ANDROID OR IOS)
  message(WARNING "Windows, Mac or Mobile are not supported with libxsmm in Paddle yet.")
  set(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM" FORCE)
  return()
endif()

include(ExternalProject)

set(LIBXSMM_SOURCES_DIR ${THIRD_PARTY_PATH}/libxsmm)
set(LIBXSMM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/libxsmm)
set(LIBXSMM_INCLUDE_DIR "${LIBXSMM_INSTALL_DIR}/include" CACHE PATH "LIBXSMM include directory." FORCE)
set(LIBXSMM_LIBRARY_DIR "${LIBXSMM_INSTALL_DIR}/lib" CACHE PATH "LIBXSMM library directory." FORCE)
# Both static archives are needed by consumers: libxsmmnoblas provides the
# fallback symbols libxsmm references when no full BLAS is linked.
set(LIBXSMM_LIBS "${LIBXSMM_LIBRARY_DIR}/libxsmm.a"
                 "${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a")

# Download and build libxsmm from a pinned commit; its plain Makefile build
# installs directly into LIBXSMM_INSTALL_DIR, so no separate install step.
ExternalProject_Add(
  extern_libxsmm
  GIT_REPOSITORY  "https://github.com/hfp/libxsmm.git"
  GIT_TAG         "7cc03b5b342fdbc6b6d990b190671c5dbb8489a2"
  PREFIX          ${LIBXSMM_SOURCES_DIR}
  UPDATE_COMMAND  ""
  CONFIGURE_COMMAND ""
  BUILD_IN_SOURCE 1
  BUILD_COMMAND   $(MAKE) --silent PREFIX=${LIBXSMM_INSTALL_DIR} CXX=g++ CC=gcc WARP=0 install
  INSTALL_COMMAND ""
)

# BUG FIX: the original called SET_PROPERTY(... IMPORTED_LOCATION ...) twice on
# the same `libxsmm` target; the second call silently overwrote the first, so
# the imported target pointed at libxsmmnoblas.a only and libxsmm.a was never
# attached to any target. Model the two archives as two imported targets.
add_library(libxsmm STATIC IMPORTED GLOBAL)
set_property(TARGET libxsmm PROPERTY IMPORTED_LOCATION "${LIBXSMM_LIBRARY_DIR}/libxsmm.a")
add_library(libxsmmnoblas STATIC IMPORTED GLOBAL)
set_property(TARGET libxsmmnoblas PROPERTY IMPORTED_LOCATION "${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a")
add_dependencies(libxsmm extern_libxsmm)
add_dependencies(libxsmmnoblas extern_libxsmm)

message(STATUS "Libxsmm library: ${LIBXSMM_LIBS}")
include_directories(${LIBXSMM_INCLUDE_DIR})
add_definitions(-DPADDLE_WITH_LIBXSMM)
list(APPEND external_project_dependencies libxsmm)
57+

cmake/external/openblas.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ ELSE()
121121
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
122122
ENDIF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
123123

124+
IF(WITH_LIBXSMM)
125+
TARGET_LINK_LIBRARIES(cblas ${LIBXSMM_LIBS})
126+
ADD_DEPENDENCIES(cblas extern_libxsmm)
127+
ENDIF()
128+
124129
IF(NOT ${CBLAS_FOUND})
125130
ADD_DEPENDENCIES(cblas extern_openblas)
126131
LIST(APPEND external_project_dependencies cblas)

paddle/fluid/operators/math/blas.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
#include "paddle/fluid/platform/dynload/mklml.h"
2222
#endif
2323

24+
#ifdef PADDLE_WITH_LIBXSMM
25+
#include <libxsmm.h>
26+
#endif
27+
2428
#ifdef PADDLE_USE_OPENBLAS
2529
#include <cblas.h>
2630
#endif

paddle/fluid/operators/math/blas_impl.h

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#pragma once
15+
#include <limits>
1516
#include <vector>
1617
#include "paddle/fluid/operators/math/math_function.h"
1718

@@ -30,6 +31,12 @@ struct CBlas<float> {
3031
platform::dynload::cblas_sgemm(args...);
3132
}
3233

34+
#ifdef PADDLE_WITH_LIBXSMM
35+
template <typename... ARGS>
36+
static void SMM_GEMM(ARGS... args) {
37+
libxsmm_sgemm(args...);
38+
}
39+
#endif
3340
template <typename... ARGS>
3441
static void AXPY(ARGS... args) {
3542
platform::dynload::cblas_saxpy(args...);
@@ -63,6 +70,12 @@ struct CBlas<double> {
6370
platform::dynload::cblas_dgemm(args...);
6471
}
6572

73+
#ifdef PADDLE_WITH_LIBXSMM
74+
template <typename... ARGS>
75+
static void SMM_GEMM(ARGS... args) {
76+
libxsmm_dgemm(args...);
77+
}
78+
#endif
6679
template <typename... ARGS>
6780
static void AXPY(ARGS... args) {
6881
platform::dynload::cblas_daxpy(args...);
@@ -140,13 +153,43 @@ struct CBlas<double> {
140153
template <>
141154
struct CBlas<platform::float16> {
142155
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
156+
static void SMM_GEMM(...) {
157+
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
158+
}
143159
#ifdef PADDLE_WITH_MKLML
144160
static void GEMM_BATCH(...) {
145161
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
146162
}
147163
#endif
148164
};
149165

166+
// Decide whether a GEMM should be dispatched to libxsmm's small-matrix kernel
// instead of the regular BLAS path.
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom: libxsmm is only used for small, untransposed,
// plain (alpha == 1, beta == 0) products.
template <typename T>
inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
                    bool transb, const T &alpha, const T &beta) {
  constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
  // BUG FIX: the original wrote
  //   std::abs<T>(alpha - static_cast<T>(1) > std::numeric_limits<T>::epsilon())
  // which nests the comparison *inside* abs() (so it takes abs of a bool) and
  // passes an explicit template argument that the std::abs overload set does
  // not accept. The intent is |alpha - 1| <= eps and |beta| <= eps.
  const bool small_plain_gemm =
      m * n * k <= LIBXSMM_THRESHOLD && !transa && !transb &&
      std::abs(alpha - static_cast<T>(1)) <=
          std::numeric_limits<T>::epsilon() &&
      std::abs(beta) <= std::numeric_limits<T>::epsilon();
#ifdef PADDLE_WITH_LIBXSMM
  return small_plain_gemm;
#else
  // Without libxsmm compiled in there is nothing to dispatch to.
  (void)small_plain_gemm;
  return false;
#endif
}
184+
185+
template <>
186+
inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
187+
bool transa, bool transb,
188+
const platform::float16 &alpha,
189+
const platform::float16 &beta) {
190+
return false;
191+
}
192+
150193
template <>
151194
template <typename T>
152195
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
@@ -156,8 +199,21 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
156199
int lda = (transA == CblasNoTrans) ? K : M;
157200
int ldb = (transB == CblasNoTrans) ? N : K;
158201
int ldc = N;
159-
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
160-
beta, C, ldc);
202+
#ifdef PADDLE_WITH_LIBXSMM
203+
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
204+
beta)) {
205+
// Note: SMM use ColMajor
206+
const char transa = 'N';
207+
const char transb = 'N';
208+
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
209+
&beta, C, &ldc);
210+
} else {
211+
#endif
212+
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
213+
ldb, beta, C, ldc);
214+
#ifdef PADDLE_WITH_LIBXSMM
215+
}
216+
#endif
161217
}
162218

163219
template <>

paddle/fluid/operators/math/math_function_test.cc

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,64 @@ TEST(math_function, gemm_notrans_cblas) {
5454
EXPECT_EQ(input3_ptr[6], 86);
5555
EXPECT_EQ(input3_ptr[7], 99);
5656
}
57+
#ifdef PADDLE_WITH_LIBXSMM
58+
template <typename T>
59+
void MklSmmCompare(int m, int n, int k) {
60+
paddle::framework::Tensor mat_a;
61+
paddle::framework::Tensor mat_b;
62+
paddle::framework::Tensor mat_c_smm;
63+
paddle::framework::Tensor mat_c_mkl;
64+
auto* cpu_place = new paddle::platform::CPUPlace();
65+
66+
T* A = mat_a.mutable_data<T>({m, k}, *cpu_place);
67+
T* B = mat_b.mutable_data<T>({k, n}, *cpu_place);
68+
T* CSMM = mat_c_smm.mutable_data<T>({m, n}, *cpu_place);
69+
T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, *cpu_place);
70+
T alpha = static_cast<T>(1);
71+
T beta = static_cast<T>(0);
72+
for (int i = 0; i < mat_a.numel(); ++i) {
73+
A[i] = static_cast<T>(i);
74+
}
75+
for (int i = 0; i < mat_b.numel(); ++i) {
76+
B[i] = static_cast<T>(i);
77+
}
78+
// lda,ldb,ldc follow RowMajor
79+
int lda = k;
80+
int ldb = n;
81+
int ldc = n;
82+
83+
auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
84+
const char transa = 'N';
85+
const char transb = 'N';
86+
paddle::operators::math::CBlas<T>::SMM_GEMM(&transa, &transb, &n, &m, &k,
87+
&alpha, B, &ldb, A, &lda, &beta,
88+
CSMM, &ldc);
89+
};
90+
91+
auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
92+
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
93+
CblasNoTrans, m, n, k, alpha, A,
94+
lda, B, ldb, beta, CMKL, ldc);
95+
};
96+
97+
smm();
98+
mkl();
99+
ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel());
100+
for (int i = 0; i < mat_c_mkl.numel(); ++i) {
101+
EXPECT_FLOAT_EQ(CSMM[i], CMKL[i]);
102+
}
103+
}
104+
TEST(math_function, gemm_mkl_vs_smm) {
  // Tall, wide and mid-sized shapes, each checked in both precisions
  // (float first, then double, matching the original call order).
  const int shapes[][3] = {{1, 2, 3}, {3, 2, 1}, {3, 8, 5}};
  for (const auto& s : shapes) {
    MklSmmCompare<float>(s[0], s[1], s[2]);
    MklSmmCompare<double>(s[0], s[1], s[2]);
  }
}
112+
#endif
57113

58-
TEST(math_function, gemm_trans_clbas) {
114+
TEST(math_function, gemm_trans_cblas) {
59115
paddle::framework::Tensor input1;
60116
paddle::framework::Tensor input2;
61117
paddle::framework::Tensor input3;

0 commit comments

Comments
 (0)