Skip to content

Commit 992c07f

Browse files
committed
feat: kernels to file from einsumtree
1 parent fa953cd commit 992c07f

File tree

9 files changed

+99
-1
lines changed

9 files changed

+99
-1
lines changed

CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,20 @@ if(OpenMP_CXX_FOUND)
311311
target_link_libraries(tests PRIVATE OpenMP::OpenMP_CXX)
312312
endif()
313313

314+
# tests sanitized: test executable compiled and linked with AddressSanitizer and
# UndefinedBehaviorSanitizer (plus bounds and float-divide-by-zero checks at compile time).
315+
add_executable(tests_sanitized "${SOURCE_FILEPATHS}" "${TEST_FILEPATHS}")
316+
if(SAVE_JITS_TO_FILE)
317+
target_compile_definitions(tests_sanitized PUBLIC SAVE_JITS_TO_FILE)
318+
endif(SAVE_JITS_TO_FILE)
319+
320+
target_compile_options(tests_sanitized PRIVATE -g -fsanitize=float-divide-by-zero -fsanitize=bounds -fsanitize=address -fsanitize=undefined -fno-omit-frame-pointer)
321+
target_link_options(tests_sanitized PRIVATE -g -fsanitize=address -fsanitize=undefined)
322+
target_link_libraries(tests_sanitized PRIVATE Catch2::Catch2WithMain)
323+
324+
if(OpenMP_CXX_FOUND)
325+
target_link_libraries(tests_sanitized PRIVATE OpenMP::OpenMP_CXX)
326+
endif()
327+
314328
# benchmarks
315329
add_executable(benchmarks "${SOURCE_FILEPATHS}" "${BENCH_FILEPATHS}")
316330

src/main/Brgemm.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ mini_jit::Brgemm::kernel_t mini_jit::Brgemm::get_kernel() const
4747
return kernel;
4848
}
4949

50+
// Writes the current kernel to the file at `path` by delegating to native_kernel.write().
// NOTE(review): error handling and file format are determined by Kernel::write, which is not visible here.
void mini_jit::Brgemm::write_kernel_to_file(const char *path) const
51+
{
52+
native_kernel.write(path);
53+
}
54+
5055
void mini_jit::Brgemm::fill_with_matmuls_no_batch_dim_column_major_fp32(uint32_t m, uint32_t n, uint32_t k)
5156
{
5257
// Always sort from the specific to the more general case

src/main/Brgemm.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ class mini_jit::Brgemm
6666
**/
6767
kernel_t get_kernel() const;
6868

69+
/**
70+
* @brief Writes the current kernel into a file.
71+
*
72+
* @param path The file to write the kernel to.
73+
*/
74+
void write_kernel_to_file(const char *path) const;
75+
6976
private:
7077
kernel_t kernel = nullptr;
7178
mini_jit::Kernel native_kernel;

src/main/EinsumTree.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,16 @@ mini_jit::EinsumTree::~EinsumTree()
159159

160160
void mini_jit::EinsumTree::delete_tree(EinsumNode *node)
161161
{
162-
if (!node)
162+
if (node == nullptr)
163+
{
163164
return;
165+
}
166+
164167
delete_tree(node->left);
165168
delete_tree(node->right);
169+
node->left = nullptr;
170+
node->right = nullptr;
171+
166172
if (node->type != NodeType::Leaf && node->tensor != nullptr && node != get_root())
167173
{
168174
delete[] node->tensor;
@@ -224,6 +230,16 @@ std::string mini_jit::EinsumTree::EinsumNode::to_string() const
224230
return mini_jit::EinsumTree::EinsumNode::_to_string(0, "", "");
225231
}
226232

233+
/**
 * Builds a name for this node from its output dimension ids, joined by '_'.
 *
 * A node without output dimension ids yields an empty string; the previous
 * implementation indexed element 0 unconditionally, which is undefined
 * behaviour on an empty container.
 *
 * @return The joined output dimension ids, or an empty string if there are none.
 */
std::string mini_jit::EinsumTree::EinsumNode::name() const
{
  std::string output;
  bool is_first = true;
  for (const auto &dim_id : output_dim_ids)
  {
    // Separator goes between ids only, never in front of the first one.
    if (!is_first)
    {
      output += '_';
    }
    output += std::format("{}", dim_id);
    is_first = false;
  }
  return output;
}
242+
227243
mini_jit::TensorConfig mini_jit::EinsumTree::lower_node(const EinsumNode *node)
228244
{
229245
// Node has two children -> contraction
@@ -477,6 +493,9 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
477493
return error;
478494
}
479495

496+
#ifdef SAVE_JITS_TO_FILE
497+
tensor_op.write_kernel_to_file(node->name());
498+
#endif // SAVE_JITS_TO_FILE
480499
tensor_op.execute(node->left->tensor, nullptr, node->tensor);
481500
}
482501
else if (node->type == NodeType::Contraction)
@@ -517,6 +536,9 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
517536
return error;
518537
}
519538

539+
#ifdef SAVE_JITS_TO_FILE
540+
tensor_op.write_kernel_to_file(node->name());
541+
#endif // SAVE_JITS_TO_FILE
520542
tensor_op.execute(node->left->tensor, node->right->tensor, node->tensor);
521543
}
522544
else

src/main/EinsumTree.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ namespace mini_jit
7676
*/
7777
std::string to_string() const;
7878

79+
/**
80+
* @brief Gets a name for the node built from its output dimension ids, joined by underscores.
81+
*/
82+
std::string name() const;
83+
7984
/**
8085
* Get the size of the tensor represented by this node.
8186
*

src/main/TensorOperation.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "TensorOperation.h"
22
#include "TensorOptimization.h"
33
#include "release_assert.h"
4+
#include <format>
45
#include <iostream>
56
#include <omp.h>
67
#include <ranges>
@@ -809,4 +810,29 @@ void mini_jit::TensorOperation::execute_dimension(int64_t index_dim, char const
809810
mini_jit::TensorConfig mini_jit::TensorOperation::get_config()
810811
{
811812
return config;
813+
}
814+
815+
/**
 * Writes each kernel used by this operation to its own file.
 *
 * File names are derived from the given prefix:
 *   <prefix>_first_touch.bin, <prefix>_main.bin and <prefix>_last_touch.bin.
 * A kernel is written only when its primitive is not `none` and the matching
 * variant currently holds a Unary (or, for the main kernel, a Brgemm or Unary).
 *
 * @param path_no_extension Path prefix of the output files, without extension.
 */
void mini_jit::TensorOperation::write_kernel_to_file(std::string path_no_extension) const
{
  if (prim_first != TensorConfig::prim_t::none && std::holds_alternative<Unary>(first_touch))
  {
    const std::string first_touch_path = std::format("{}_first_touch.bin", path_no_extension);
    std::get<Unary>(first_touch).write_kernel_to_file(first_touch_path.c_str());
  }

  if (prim_main != TensorConfig::prim_t::none)
  {
    const std::string main_path = std::format("{}_main.bin", path_no_extension);
    if (std::holds_alternative<Brgemm>(main_kernel))
    {
      std::get<Brgemm>(main_kernel).write_kernel_to_file(main_path.c_str());
    }
    else if (std::holds_alternative<Unary>(main_kernel))
    {
      std::get<Unary>(main_kernel).write_kernel_to_file(main_path.c_str());
    }
  }

  if (prim_last != TensorConfig::prim_t::none && std::holds_alternative<Unary>(last_touch))
  {
    // Uses its own suffix: this previously reused "_first_touch.bin" and
    // overwrote the first-touch kernel file.
    const std::string last_touch_path = std::format("{}_last_touch.bin", path_no_extension);
    std::get<Unary>(last_touch).write_kernel_to_file(last_touch_path.c_str());
  }
}

src/main/TensorOperation.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,13 @@ namespace mini_jit
242242
* @return TensorConfig used by the Tensor operation.
243243
*/
244244
TensorConfig get_config();
245+
246+
/**
247+
* @brief Writes the current kernel into a file.
248+
*
249+
* @param path_no_extension The file path to write the kernel(s) to, without extension; a per-kernel suffix and `.bin` are appended.
250+
*/
251+
void write_kernel_to_file(std::string path_no_extension) const;
245252
};
246253
}; // namespace mini_jit
247254

src/main/Unary.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,9 @@ void mini_jit::Unary::relu_unary_fp32(uint32_t m, uint32_t n, uint32_t trans_b)
104104
kernels::unary_relu(native_kernel, m, n); // logic of zero_16m_n combined with rest processing
105105
}
106106
return;
107+
}
108+
109+
// Writes the current kernel to the file at `path` by delegating to native_kernel.write().
// NOTE(review): error handling and file format are determined by Kernel::write, which is not visible here.
void mini_jit::Unary::write_kernel_to_file(const char *path) const
110+
{
111+
native_kernel.write(path);
107112
}

src/main/Unary.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ class mini_jit::Unary
9292
* @return pointer to the generated kernel.
9393
**/
9494
kernel_t get_kernel() const;
95+
96+
/**
97+
* @brief Writes the current kernel into a file.
98+
*
99+
* @param path The file to write the kernel to.
100+
*/
101+
void write_kernel_to_file(const char *path) const;
95102
};
96103

97104
#endif

0 commit comments

Comments
 (0)