Skip to content

Commit 911c01d

Browse files
committed
feat: testing tensor ops
1 parent 6fc67e1 commit 911c01d

File tree

4 files changed

+17
-15
lines changed

4 files changed

+17
-15
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@
100100
"NVCOMPILER",
101101
"pnacl",
102102
"relu",
103-
"xtensor"
103+
"xtensor",
104+
"linalg"
104105
],
105106

106107
// Include Paths

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ add_executable(tests "${SOURCE_FILEPATHS}" "${TEST_FILEPATHS}")
298298
if(SAVE_JITS_TO_FILE)
299299
target_compile_definitions(tests PUBLIC SAVE_JITS_TO_FILE)
300300
endif(SAVE_JITS_TO_FILE)
301-
target_link_libraries(tests PRIVATE Catch2::Catch2WithMain xtensor)
301+
target_link_libraries(tests PRIVATE Catch2::Catch2WithMain xtensor xtensor-blas)
302302

303303
add_executable(benchmarks "${SOURCE_FILEPATHS}" "${BENCH_FILEPATHS}")
304304
target_link_libraries(benchmarks benchmark::benchmark_main)

src/test/BaseGeneration.test.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,22 @@ class GenerationTest
2323
float *matrix_c_verify;
2424
mini_jit::Brgemm::kernel_t kernel = nullptr;
2525

26+
public:
2627
/**
2728
* @brief Fills the given matrix with random values.
2829
*
2930
* @param matrix The matrix to fill.
3031
* @param size The total size of the matrix.
3132
*/
32-
void fill_random_matrix(float *matrix, uint32_t size);
33+
static void fill_random_matrix(float *matrix, uint32_t size);
3334

3435
/**
3536
* @brief Fills the given matrix with counting values starting from 0.
3637
*
3738
* @param matrix The matrix to fill.
3839
* @param size The total size of the matrix.
3940
*/
40-
void fill_counting_matrix(float *matrix, uint32_t size);
41+
static void fill_counting_matrix(float *matrix, uint32_t size);
4142

4243
/**
4344
* @brief Does a naive matmul for verification usage.
@@ -61,9 +62,8 @@ class GenerationTest
6162
* @param result The actual matrix values.
6263
* @param size The total size of the matrix.
6364
*/
64-
void verify_matmul(const float *__restrict__ expected, const float *__restrict__ result, uint32_t size);
65+
static void verify_matmul(const float *__restrict__ expected, const float *__restrict__ result, uint32_t size);
6566

66-
public:
6767
GenerationTest() = delete;
6868
GenerationTest(uint32_t M, uint32_t N, uint32_t K);
6969
GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t lda, uint32_t ldb, uint32_t ldc);

src/test/TensorOperation.test.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "../main/TensorOperation.h"
2+
#include "BaseGeneration.test.h"
23
#include <catch2/catch_test_macros.hpp>
34
#include <catch2/generators/catch_generators.hpp>
45
#include <catch2/generators/catch_generators_range.hpp>
@@ -7,6 +8,7 @@
78
#include <cmath>
89
#include <cstdint>
910
#include <span>
11+
#include <xtensor-blas/xlinalg.hpp>
1012
#include <xtensor/containers/xtensor.hpp>
1113
#include <xtensor/generators/xrandom.hpp>
1214

@@ -102,8 +104,6 @@ TEST_CASE("Test tensor operation with main kernel: gemm", "[tensor_operation][ge
102104
xt::xtensor<float, 2> tensorOutVerify({64, 64});
103105
std::copy(tensorOut.begin(), tensorOut.end(), tensorOutVerify.begin());
104106

105-
std::cout << tensorIn0[0] << " " << tensorIn1[0] << " " << tensorOutVerify[0] << " " << tensorOut[0] << std::endl;
106-
107107
mini_jit::TensorOperation tensor_op;
108108
TensorOperation::error_t err =
109109
tensor_op.setup(TensorOperation::dtype_t::fp32, TensorOperation::prim_t::none, TensorOperation::prim_t::gemm,
@@ -112,14 +112,14 @@ TEST_CASE("Test tensor operation with main kernel: gemm", "[tensor_operation][ge
112112

113113
REQUIRE(err == TensorOperation::error_t::success);
114114

115-
tensor_op.execute(tensorIn0.data(), tensorIn1.data(), tensorOut.data());
115+
// tensorOut += xt::linalg::dot(tensorIn0, tensorIn1);
116116

117-
std::cout << tensorOutVerify[0] << " " << tensorOut[0] << std::endl;
117+
GenerationTest generatorTest(64, 64, 64);
118+
generatorTest.naive_matmul_M_N_K_Batch(tensorIn0.data(), tensorIn1.data(), tensorOutVerify.data(), 64, 64, 64, 1, 1);
118119

119-
// TODO: Implement the verification logic for naive gemm operation
120+
tensor_op.execute(tensorIn0.data(), tensorIn1.data(), tensorOut.data());
120121

121122
verify_tensor(tensorOutVerify.data(), tensorOut.data(), tensorOut.size());
122-
FAIL();
123123
}
124124

125125
TEST_CASE("Test tensor operation with main kernel: brgemm", "[tensor_operation][brgemm][correctness]")
@@ -131,8 +131,8 @@ TEST_CASE("Test tensor operation with main kernel: brgemm", "[tensor_operation][
131131
constexpr TensorOperation::exec_t exec_types[]{TensorOperation::exec_t::prim, TensorOperation::exec_t::prim,
132132
TensorOperation::exec_t::prim, TensorOperation::exec_t::prim};
133133
constexpr int64_t dim_sizes[]{64, 64, 64, 64};
134-
constexpr int64_t strides_in0[]{64, 1, 0, 64};
135-
constexpr int64_t strides_in1[]{64, 0, 64, 1};
134+
constexpr int64_t strides_in0[]{64 * 64, 1, 0, 64};
135+
constexpr int64_t strides_in1[]{64 * 64, 0, 64, 1};
136136
constexpr int64_t strides_out[]{0, 1, 64, 0};
137137

138138
xt::random::seed(Catch::rngSeed());
@@ -156,7 +156,8 @@ TEST_CASE("Test tensor operation with main kernel: brgemm", "[tensor_operation][
156156

157157
tensor_op.execute(tensorIn0.data(), tensorIn1.data(), tensorOut.data());
158158

159-
// TODO: Implement the verification logic for naive brgemm operation
159+
GenerationTest generatorTest(64, 64, 64, 64);
160+
generatorTest.naive_matmul_M_N_K_Batch(tensorIn0.data(), tensorIn1.data(), tensorOutVerify.data(), 64, 64, 64, 64 * 64, 64 * 64);
160161

161162
verify_tensor(tensorOutVerify.data(), tensorOut.data(), tensorOut.size());
162163
}

0 commit comments

Comments
 (0)