12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
#pragma once
15
+ #include < limits>
15
16
#include < vector>
16
17
#include " paddle/fluid/operators/math/math_function.h"
17
18
@@ -161,6 +162,25 @@ struct CBlas<platform::float16> {
161
162
}
162
163
#endif
163
164
};
165
+ template <typename T>
166
+ inline static bool UseXSMM (const int &m, const int &n, const int &k,
167
+ bool transa, bool transb, const T &alpha,
168
+ const T &beta) {
169
+ #ifdef PADDLE_WITH_LIBXSMM
170
+ // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
171
+ // But the threshold is custom
172
+ constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20 ;
173
+ if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
174
+ std::abs<T>(alpha - static_cast <T>(1 ) >
175
+ std::numeric_limits<T>::epsilon ()) ||
176
+ std::abs<T>(beta) > std::numeric_limits<T>::epsilon ()) {
177
+ return false ;
178
+ } else {
179
+ return true ;
180
+ }
181
+ #endif
182
+ return false ;
183
+ }
164
184
165
185
template <>
166
186
template <typename T>
@@ -172,8 +192,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
172
192
int ldb = (transB == CblasNoTrans) ? N : K;
173
193
int ldc = N;
174
194
#ifdef PADDLE_WITH_LIBXSMM
175
- if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans &&
176
- transB == CblasNoTrans ) {
195
+ if (UseXSMM (M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
196
+ beta) ) {
177
197
// refer to https://github.com/hfp/libxsmm/blob/master/README.md
178
198
// Note: SMM use ColMajor
179
199
const char transa = ' N' ;
0 commit comments