Skip to content

Commit fc2b578

Browse files
committed
add gemm_warp test
1 parent a916c52 commit fc2b578

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

paddle/fluid/operators/math/math_function_test.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) {
228228
}
229229
delete ctx;
230230
}
231+
232+
template <typename T>
233+
void GemmWarpTest(int m, int n, int k, T alpha, T beta) {
234+
paddle::framework::Tensor mat_a;
235+
paddle::framework::Tensor mat_b;
236+
paddle::framework::Tensor mat_c_ref;
237+
paddle::framework::Tensor mat_c_mkl;
238+
auto* cpu_place = new paddle::platform::CPUPlace();
239+
240+
T* A = mat_a.mutable_data<T>({m, k}, *cpu_place);
241+
T* B = mat_b.mutable_data<T>({k, n}, *cpu_place);
242+
T* CREF = mat_c_ref.mutable_data<T>({m, n}, *cpu_place);
243+
T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, *cpu_place);
244+
245+
ASSERT_EQ(mat_c_mkl.numel(), mat_c_ref.numel());
246+
for (int i = 0; i < mat_a.numel(); ++i) {
247+
A[i] = static_cast<T>(i);
248+
}
249+
for (int i = 0; i < mat_b.numel(); ++i) {
250+
B[i] = static_cast<T>(i + 1);
251+
}
252+
for (int i = 0; i < mat_c_ref.numel(); ++i) {
253+
CREF[i] = static_cast<T>(i + 2);
254+
CMKL[i] = CREF[i];
255+
}
256+
257+
// this would call gemm_warp
258+
paddle::platform::CPUDeviceContext context(*cpu_place);
259+
GetBlas<T>(context).GEMM(CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, B,
260+
beta, CREF);
261+
262+
// lda,ldb,ldc follow RowMajor
263+
int lda = k;
264+
int ldb = n;
265+
int ldc = n;
266+
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
267+
CblasNoTrans, m, n, k, alpha, A, lda,
268+
B, ldb, beta, CMKL, ldc);
269+
270+
for (int i = 0; i < mat_c_mkl.numel(); ++i) {
271+
EXPECT_FLOAT_EQ(CREF[i], CMKL[i]);
272+
}
273+
}
274+
275+
// Runs GemmWarpTest for float and then double over two {m, n, k} problem
// sizes, each with (alpha, beta) = (1, 0) followed by (2, 1).
TEST(math_function, gemm_warp) {
  const int kShapes[2][3] = {{3, 2, 5}, {8, 5, 6}};

  // Single precision.
  for (int s = 0; s < 2; ++s) {
    const int m = kShapes[s][0];
    const int n = kShapes[s][1];
    const int k = kShapes[s][2];
    GemmWarpTest<float>(m, n, k, 1.f, 0.f);
    GemmWarpTest<float>(m, n, k, 2.f, 1.f);
  }

  // Double precision.
  for (int s = 0; s < 2; ++s) {
    const int m = kShapes[s][0];
    const int n = kShapes[s][1];
    const int k = kShapes[s][2];
    GemmWarpTest<double>(m, n, k, 1.0, 0.0);
    GemmWarpTest<double>(m, n, k, 2.0, 1.0);
  }
}

0 commit comments

Comments
 (0)