11#include " ../main/TensorOperation.h"
2+ #include " BaseGeneration.test.h"
23#include < catch2/catch_test_macros.hpp>
34#include < catch2/generators/catch_generators.hpp>
45#include < catch2/generators/catch_generators_range.hpp>
78#include < cmath>
89#include < cstdint>
910#include < span>
11+ #include < xtensor-blas/xlinalg.hpp>
1012#include < xtensor/containers/xtensor.hpp>
1113#include < xtensor/generators/xrandom.hpp>
1214
@@ -102,8 +104,6 @@ TEST_CASE("Test tensor operation with main kernel: gemm", "[tensor_operation][ge
102104 xt::xtensor<float , 2 > tensorOutVerify ({64 , 64 });
103105 std::copy (tensorOut.begin (), tensorOut.end (), tensorOutVerify.begin ());
104106
105- std::cout << tensorIn0[0 ] << " " << tensorIn1[0 ] << " " << tensorOutVerify[0 ] << " " << tensorOut[0 ] << std::endl;
106-
107107 mini_jit::TensorOperation tensor_op;
108108 TensorOperation::error_t err =
109109 tensor_op.setup (TensorOperation::dtype_t ::fp32, TensorOperation::prim_t ::none, TensorOperation::prim_t ::gemm,
@@ -112,14 +112,14 @@ TEST_CASE("Test tensor operation with main kernel: gemm", "[tensor_operation][ge
112112
113113 REQUIRE (err == TensorOperation::error_t ::success);
114114
115- tensor_op. execute (tensorIn0. data () , tensorIn1. data (), tensorOut. data () );
115+ // tensorOut += xt::linalg::dot (tensorIn0, tensorIn1);
116116
117- std::cout << tensorOutVerify[0 ] << " " << tensorOut[0 ] << std::endl;
117+ GenerationTest generatorTest (64 , 64 , 64 );
118+ generatorTest.naive_matmul_M_N_K_Batch (tensorIn0.data (), tensorIn1.data (), tensorOutVerify.data (), 64 , 64 , 64 , 1 , 1 );
118119
119- // TODO: Implement the verification logic for naive gemm operation
120+ tensor_op. execute (tensorIn0. data (), tensorIn1. data (), tensorOut. data ());
120121
121122 verify_tensor (tensorOutVerify.data (), tensorOut.data (), tensorOut.size ());
122- FAIL ();
123123}
124124
125125TEST_CASE (" Test tensor operation with main kernel: brgemm" , " [tensor_operation][brgemm][correctness]" )
@@ -131,8 +131,8 @@ TEST_CASE("Test tensor operation with main kernel: brgemm", "[tensor_operation][
131131 constexpr TensorOperation::exec_t exec_types[]{TensorOperation::exec_t ::prim, TensorOperation::exec_t ::prim,
132132 TensorOperation::exec_t ::prim, TensorOperation::exec_t ::prim};
133133 constexpr int64_t dim_sizes[]{64 , 64 , 64 , 64 };
134- constexpr int64_t strides_in0[]{64 , 1 , 0 , 64 };
135- constexpr int64_t strides_in1[]{64 , 0 , 64 , 1 };
134+ constexpr int64_t strides_in0[]{64 * 64 , 1 , 0 , 64 };
135+ constexpr int64_t strides_in1[]{64 * 64 , 0 , 64 , 1 };
136136 constexpr int64_t strides_out[]{0 , 1 , 64 , 0 };
137137
138138 xt::random::seed (Catch::rngSeed ());
@@ -156,7 +156,8 @@ TEST_CASE("Test tensor operation with main kernel: brgemm", "[tensor_operation][
156156
157157 tensor_op.execute (tensorIn0.data (), tensorIn1.data (), tensorOut.data ());
158158
159- // TODO: Implement the verification logic for naive brgemm operation
159+ GenerationTest generatorTest (64 , 64 , 64 , 64 );
160+ generatorTest.naive_matmul_M_N_K_Batch (tensorIn0.data (), tensorIn1.data (), tensorOutVerify.data (), 64 , 64 , 64 , 64 * 64 , 64 * 64 );
160161
161162 verify_tensor (tensorOutVerify.data (), tensorOut.data (), tensorOut.size ());
162163}
0 commit comments