Skip to content

Commit 6d37acf

Browse files
committed
feat: finished einsum tree
1 parent 88e612c commit 6d37acf

File tree

8 files changed

+989
-176
lines changed

8 files changed

+989
-176
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ set(TEST_ARM_INSTRUCTION_FILES
250250
set(BENCH_FILES
251251
TensorOperation.bench.cpp
252252
TensorOptimization.bench.cpp
253+
EinsumTree.bench.cpp
253254
)
254255

255256
set(BENCH_KERNLES_FILES

src/main/EinsumTree.cpp

Lines changed: 258 additions & 68 deletions
Large diffs are not rendered by default.

src/main/EinsumTree.h

Lines changed: 142 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define MINI_JIT_EINSUM_TREE_H
33

44
#include "TensorConfig.h"
5+
#include "TensorOperation.h"
56
#include <cstdint>
67
#include <map>
78
#include <string>
@@ -30,6 +31,23 @@ namespace mini_jit
3031
InvalidRoot = 1,
3132
NotEnoughInputTensors = 2,
3233
TooManyInputTensors = 3,
34+
UndefinedNode = 4,
35+
36+
err_wrong_dtype = 101,
37+
err_wrong_dimension = 102,
38+
err_wrong_primitive = 103,
39+
err_wrong_first_touch_primitive = 104,
40+
err_wrong_main_primitive = 105,
41+
err_wrong_last_touch_primitive = 106,
42+
err_execution_type_not_supported = 107,
43+
err_invalid_primitive_configuration = 108,
44+
err_invalid_first_touch_configuration = 109,
45+
err_invalid_main_configuration = 110,
46+
err_invalid_last_touch_configuration = 111,
47+
err_invalid_execution_order = 112,
48+
err_invalid_strides = 113,
49+
err_k_dimension_must_not_be_shared = 114,
50+
err_shared_required_for_parallel_execution = 115,
3351
};
3452

3553
enum class NodeType
@@ -44,22 +62,34 @@ namespace mini_jit
4462
NodeType type;
4563
float *tensor;
4664

47-
// For leaf and contraction nodes
48-
std::vector<int64_t> input_dims0;
49-
std::vector<int64_t> input_dims1; // Only used for contraction
50-
5165
// Always filled — dims of the output tensor
52-
std::vector<int64_t> output_dims;
66+
std::vector<int64_t> output_dim_ids;
5367

5468
// Pointers to children
5569
EinsumNode *left = nullptr;
5670
EinsumNode *right = nullptr;
5771

72+
/**
73+
* Gets a string representation of the einsum tree.
74+
*/
5875
std::string to_string() const;
5976

60-
int64_t get_size() const;
77+
/**
78+
* Get the size of the tensor represented by this node.
79+
*
80+
* @param dim_sizes A vector of dimension sizes corresponding to the output dimensions.
81+
*/
82+
int64_t get_size(const std::vector<int64_t> dim_sizes) const;
6183

6284
private:
85+
/**
86+
* This method recursively formats the node and its children into a string.
87+
*
88+
* @param depth The current depth in the tree, used for indentation.
89+
* @param connection A string representing the connection type.
90+
* @param depthString A string representation of the current depth.
91+
* @return A formatted string representing the einsum tree.
92+
*/
6393
std::string _to_string(uint depth, std::string connection, std::string depthString) const;
6494
};
6595

@@ -69,30 +99,131 @@ namespace mini_jit
6999
const std::string tree_str;
70100
ErrorParse error_parse;
71101
ErrorExecute error_execute;
72-
std::map<int, int> dim_sizes; // Maps dim ID to actual size
102+
std::vector<int64_t> dim_sizes;
73103

74104
// Parser
105+
/**
106+
* Parses a node from the string starting at the given position in the einsum tree.
107+
* The node can be a leaf, contraction, or transposition node.
108+
*
109+
* @param pos The position in the string to start parsing from.
110+
* @param str The string containing the einsum tree representation.
111+
* @return A pointer to the parsed EinsumNode.
112+
*/
75113
EinsumNode *parse_node(size_t &pos, const std::string &str);
76114

77115
// Lowering
116+
/**
117+
* Lowers the given EinsumNode to a TensorConfig.
118+
*
119+
* @param node The EinsumNode to lower.
120+
* @return A TensorConfig representing the lowered node.
121+
*/
78122
TensorConfig lower_node(const EinsumNode *node);
79-
void *execute_node(EinsumNode *node);
123+
124+
/**
125+
* Retrieves the dimension types and sizes for the given EinsumNode.
126+
*
127+
* @param node The EinsumNode for which to retrieve the dimension types and sizes.
128+
* @param id_map A map that associates dimension IDs with a fixed index.
129+
* @param dim_types A vector to store the dimension types.
130+
* @param dim_sizes A vector to store the dimension sizes.
131+
* @param number_of_k A reference to store the number of 'k' dimensions.
132+
*/
133+
void get_config_dim_types_and_sizes(const mini_jit::EinsumTree::EinsumNode *node, std::map<int64_t, size_t> &id_map,
134+
std::vector<mini_jit::TensorConfig::dim_t> &dim_types, std::vector<int64_t> &dim_sizes,
135+
uint32_t &number_of_k);
136+
137+
/**
138+
* Retrieves the strides for the given EinsumNode based on the provided dimension ID map.
139+
*
140+
* @param node The EinsumNode for which to retrieve the strides.
141+
* @param id_map A map that associates dimension IDs with a fixed index.
142+
* @return A vector of strides corresponding to the dimensions of the node.
143+
*/
144+
std::vector<int64_t> get_config_strides(const EinsumNode *node, std::map<int64_t, size_t> &id_map);
145+
146+
/**
147+
* Executes the tensor operation for the given EinsumNode.
148+
*
149+
* @param node The EinsumNode to execute.
150+
* @return An ErrorExecute enum indicating the result of the execution.
151+
*/
152+
ErrorExecute execute_node(EinsumNode *node);
153+
154+
/**
155+
* Assigns intermediate tensors to the given EinsumNode.
156+
*
157+
* @param tensors A vector of pointers to the intermediate tensors.
158+
* @param node The EinsumNode to which the tensors will be assigned.
159+
*/
80160
void assign_tensor(std::vector<void *> tensors, EinsumNode *node);
81161

82162
// Helpers
83-
std::vector<TensorConfig::dim_t> infer_dim_types(const std::vector<int64_t> &dims);
163+
/**
164+
* Parses a dimension list from the string starting at the given position.
165+
* The dimension list is expected to be a comma-separated list of integers.
166+
*
167+
* @param pos The position in the string to start parsing from.
168+
* @param str The string containing the dimension list.
169+
* @return A vector of integers representing the parsed dimensions.
170+
*/
84171
std::vector<int64_t> parse_dim_list(size_t &pos, const std::string &str);
85-
std::vector<int64_t> compute_strides(const std::vector<int64_t> &dims);
172+
173+
/**
174+
* Computes the strides for the given dimension IDs based on the sorted dimension sizes.
175+
*
176+
* @param dim_ids A vector of dimension IDs for which to compute the strides.
177+
* @return A vector of computed strides corresponding to the dimension IDs.
178+
*/
179+
std::vector<int64_t> compute_strides(const std::vector<int64_t> &dim_ids);
180+
181+
/**
182+
* Retrieves the output dimensions for the given dimension IDs based on the sorted dimension sizes.
183+
*
184+
* @param dim_ids A vector of dimension IDs for which to retrieve the output dimensions.
185+
* @return A vector of output dimensions corresponding to the provided dimension IDs.
186+
*/
187+
std::vector<int64_t> get_output_dims(const std::vector<int64_t> &dim_ids);
188+
189+
/**
190+
* Parses the setup error from a TensorOperation error code to an ErrorExecute enum.
191+
*
192+
* @param error The error code from TensorOperation.
193+
* @return An ErrorExecute enum representing the parsed error.
194+
*/
195+
ErrorExecute parse_setup_error(TensorOperation::error_t error);
86196

87197
// Cleanup
198+
/**
199+
* Recursively deletes the EinsumNode tree starting from the given node.
200+
*/
88201
void delete_tree(EinsumNode *node);
89202

90203
public:
91-
EinsumTree(const std::string &tree_str, const std::vector<int> &sorted_dim_sizes);
204+
EinsumTree(const std::string &tree_str, const std::vector<int64_t> &sorted_dim_sizes);
92205
~EinsumTree();
93206

207+
/**
208+
* Parses the einsum tree string and builds the tree structure.
209+
*
210+
* @return ErrorParse indicating the result of the parsing operation.
211+
*/
94212
ErrorParse parse_tree();
213+
214+
/**
215+
* Returns the root node of the EinsumTree.
216+
*
217+
* @return Pointer to the root EinsumNode.
218+
*/
95219
EinsumNode *get_root() const;
220+
221+
/**
222+
* Executes the einsum operation defined by the tree.
223+
*
224+
* @param tensors A vector of pointers to the input tensors of the leaves.
225+
* @return ErrorExecute indicating the result of the execution operation.
226+
*/
96227
ErrorExecute execute(std::vector<void *> tensors);
97228
};
98229
}; // namespace mini_jit

0 commit comments

Comments
 (0)