@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) {
228
228
}
229
229
delete ctx;
230
230
}
231
+
232
+ template <typename T>
233
+ void GemmWarpTest (int m, int n, int k, T alpha, T beta) {
234
+ paddle::framework::Tensor mat_a;
235
+ paddle::framework::Tensor mat_b;
236
+ paddle::framework::Tensor mat_c_ref;
237
+ paddle::framework::Tensor mat_c_mkl;
238
+ auto * cpu_place = new paddle::platform::CPUPlace ();
239
+
240
+ T* A = mat_a.mutable_data <T>({m, k}, *cpu_place);
241
+ T* B = mat_b.mutable_data <T>({k, n}, *cpu_place);
242
+ T* CREF = mat_c_ref.mutable_data <T>({m, n}, *cpu_place);
243
+ T* CMKL = mat_c_mkl.mutable_data <T>({m, n}, *cpu_place);
244
+
245
+ ASSERT_EQ (mat_c_mkl.numel (), mat_c_ref.numel ());
246
+ for (int i = 0 ; i < mat_a.numel (); ++i) {
247
+ A[i] = static_cast <T>(i);
248
+ }
249
+ for (int i = 0 ; i < mat_b.numel (); ++i) {
250
+ B[i] = static_cast <T>(i + 1 );
251
+ }
252
+ for (int i = 0 ; i < mat_c_ref.numel (); ++i) {
253
+ CREF[i] = static_cast <T>(i + 2 );
254
+ CMKL[i] = CREF[i];
255
+ }
256
+
257
+ // this would call gemm_warp
258
+ paddle::platform::CPUDeviceContext context (*cpu_place);
259
+ GetBlas<T>(context).GEMM (CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, B,
260
+ beta, CREF);
261
+
262
+ // lda,ldb,ldc follow RowMajor
263
+ int lda = k;
264
+ int ldb = n;
265
+ int ldc = n;
266
+ paddle::operators::math::CBlas<T>::GEMM (CblasRowMajor, CblasNoTrans,
267
+ CblasNoTrans, m, n, k, alpha, A, lda,
268
+ B, ldb, beta, CMKL, ldc);
269
+
270
+ for (int i = 0 ; i < mat_c_mkl.numel (); ++i) {
271
+ EXPECT_FLOAT_EQ (CREF[i], CMKL[i]);
272
+ }
273
+ }
274
+
275
+ TEST (math_function, gemm_warp) {
276
+ GemmWarpTest<float >(3 , 2 , 5 , 1 .f , 0 .f );
277
+ GemmWarpTest<float >(3 , 2 , 5 , 2 .f , 1 .f );
278
+ GemmWarpTest<float >(8 , 5 , 6 , 1 .f , 0 .f );
279
+ GemmWarpTest<float >(8 , 5 , 6 , 2 .f , 1 .f );
280
+ GemmWarpTest<double >(3 , 2 , 5 , 1.0 , 0.0 );
281
+ GemmWarpTest<double >(3 , 2 , 5 , 2.0 , 1.0 );
282
+ GemmWarpTest<double >(8 , 5 , 6 , 1.0 , 0.0 );
283
+ GemmWarpTest<double >(8 , 5 , 6 , 2.0 , 1.0 );
284
+ }
0 commit comments