22#define MINI_JIT_EINSUM_TREE_H
33
44#include " TensorConfig.h"
5+ #include " TensorOperation.h"
56#include < cstdint>
67#include < map>
78#include < string>
@@ -30,6 +31,23 @@ namespace mini_jit
3031 InvalidRoot = 1 ,
3132 NotEnoughInputTensors = 2 ,
3233 TooManyInputTensors = 3 ,
34+ UndefinedNode = 4 ,
35+
36+ err_wrong_dtype = 101 ,
37+ err_wrong_dimension = 102 ,
38+ err_wrong_primitive = 103 ,
39+ err_wrong_first_touch_primitive = 104 ,
40+ err_wrong_main_primitive = 105 ,
41+ err_wrong_last_touch_primitive = 106 ,
42+ err_execution_type_not_supported = 107 ,
43+ err_invalid_primitive_configuration = 108 ,
44+ err_invalid_first_touch_configuration = 109 ,
45+ err_invalid_main_configuration = 110 ,
46+ err_invalid_last_touch_configuration = 111 ,
47+ err_invalid_execution_order = 112 ,
48+ err_invalid_strides = 113 ,
49+ err_k_dimension_must_not_be_shared = 114 ,
50+ err_shared_required_for_parallel_execution = 115 ,
3351 };
3452
3553 enum class NodeType
@@ -44,22 +62,34 @@ namespace mini_jit
4462 NodeType type;
4563 float *tensor;
4664
47- // For leaf and contraction nodes
48- std::vector<int64_t > input_dims0;
49- std::vector<int64_t > input_dims1; // Only used for contraction
50-
5165 // Always filled — dims of the output tensor
52- std::vector<int64_t > output_dims ;
66+ std::vector<int64_t > output_dim_ids ;
5367
5468 // Pointers to children
5569 EinsumNode *left = nullptr ;
5670 EinsumNode *right = nullptr ;
5771
72+ /* *
73+ * Gets a string representation of the einsum tree.
74+ */
5875 std::string to_string () const ;
5976
60- int64_t get_size () const ;
77+ /* *
78+ * Get the size of the tensor represented by this node.
79+ *
80+ * @param dim_sizes A vector of dimension sizes corresponding to the output dimensions.
81+ */
82+ int64_t get_size (const std::vector<int64_t > dim_sizes) const ;
6183
6284 private:
85+ /* *
86+ * This method recursively formats the node and its children into a string.
87+ *
88+ * @param depth The current depth in the tree, used for indentation.
89+ * @param connection A string representing the connection type.
90+ * @param depthString A string representation of the current depth.
91+ * @return A formatted string representing the einsum tree.
92+ */
6393 std::string _to_string (uint depth, std::string connection, std::string depthString) const ;
6494 };
6595
@@ -69,30 +99,131 @@ namespace mini_jit
6999 const std::string tree_str;
70100 ErrorParse error_parse;
71101 ErrorExecute error_execute;
72- std::map< int , int > dim_sizes; // Maps dim ID to actual size
102+ std::vector< int64_t > dim_sizes;
73103
74104 // Parser
105+ /* *
106+ * Parses a node from the string starting at the given position in the einsum tree.
107+ * The node can be a leaf, contraction, or transposition node.
108+ *
109+ * @param pos The position in the string to start parsing from.
110+ * @param str The string containing the einsum tree representation.
111+ * @return A pointer to the parsed EinsumNode.
112+ */
75113 EinsumNode *parse_node (size_t &pos, const std::string &str);
76114
77115 // Lowering
116+ /* *
117+ * Lowers the given EinsumNode to a TensorConfig.
118+ *
119+ * @param node The EinsumNode to lower.
120+ * @return A TensorConfig representing the lowered node.
121+ */
78122 TensorConfig lower_node (const EinsumNode *node);
79- void *execute_node (EinsumNode *node);
123+
124+ /* *
125+ * Retrieves the dimension types and sizes for the given EinsumNode.
126+ *
127+ * @param node The EinsumNode for which to retrieve the dimension types and sizes.
128+ * @param id_map A map that associates dimension IDs with a fixed index.
129+ * @param dim_types A vector to store the dimension types.
130+ * @param dim_sizes A vector to store the dimension sizes.
131+ * @param number_of_k A reference to store the number of 'k' dimensions.
132+ */
133+ void get_config_dim_types_and_sizes (const mini_jit::EinsumTree::EinsumNode *node, std::map<int64_t , size_t > &id_map,
134+ std::vector<mini_jit::TensorConfig::dim_t > &dim_types, std::vector<int64_t > &dim_sizes,
135+ uint32_t &number_of_k);
136+
137+ /* *
138+ * Retrieves the strides for the given EinsumNode based on the provided dimension ID map.
139+ *
140+ * @param node The EinsumNode for which to retrieve the strides.
141+ * @param id_map A map that associates dimension IDs with a fixed index.
142+ * @return A vector of strides corresponding to the dimensions of the node.
143+ */
144+ std::vector<int64_t > get_config_strides (const EinsumNode *node, std::map<int64_t , size_t > &id_map);
145+
146+ /* *
147+ * Executes the tensor operation for the given EinsumNode.
148+ *
149+ * @param node The EinsumNode to execute.
150+ * @return An ErrorExecute enum indicating the result of the execution.
151+ */
152+ ErrorExecute execute_node (EinsumNode *node);
153+
154+ /* *
155+ * Assigns intermediate tensors to the given EinsumNode.
156+ *
157+ * @param tensors A vector of pointers to the intermediate tensors.
158+ * @param node The EinsumNode to which the tensors will be assigned.
159+ */
80160 void assign_tensor (std::vector<void *> tensors, EinsumNode *node);
81161
82162 // Helpers
83- std::vector<TensorConfig::dim_t > infer_dim_types (const std::vector<int64_t > &dims);
163+ /* *
164+ * Parses a dimension list from the string starting at the given position.
165+ * The dimension list is expected to be a comma-separated list of integers.
166+ *
167+ * @param pos The position in the string to start parsing from.
168+ * @param str The string containing the dimension list.
169+ * @return A vector of integers representing the parsed dimensions.
170+ */
84171 std::vector<int64_t > parse_dim_list (size_t &pos, const std::string &str);
85- std::vector<int64_t > compute_strides (const std::vector<int64_t > &dims);
172+
173+ /* *
174+ * Computes the strides for the given dimension IDs based on the sorted dimension sizes.
175+ *
176+ * @param dim_ids A vector of dimension IDs for which to compute the strides.
177+ * @return A vector of computed strides corresponding to the dimension IDs.
178+ */
179+ std::vector<int64_t > compute_strides (const std::vector<int64_t > &dim_ids);
180+
181+ /* *
182+ * Retrieves the output dimensions for the given dimension IDs based on the sorted dimension sizes.
183+ *
184+ * @param dim_ids A vector of dimension IDs for which to retrieve the output dimensions.
185+ * @return A vector of output dimensions corresponding to the provided dimension IDs.
186+ */
187+ std::vector<int64_t > get_output_dims (const std::vector<int64_t > &dim_ids);
188+
189+ /* *
190+ * Parses the setup error from a TensorOperation error code to an ErrorExecute enum.
191+ *
192+ * @param error The error code from TensorOperation.
193+ * @return An ErrorExecute enum representing the parsed error.
194+ */
195+ ErrorExecute parse_setup_error (TensorOperation::error_t error);
86196
87197 // Cleanup
198+ /* *
199+ * Recursively deletes the EinsumNode tree starting from the given node.
200+ */
88201 void delete_tree (EinsumNode *node);
89202
90203 public:
91- EinsumTree (const std::string &tree_str, const std::vector<int > &sorted_dim_sizes);
204+ EinsumTree (const std::string &tree_str, const std::vector<int64_t > &sorted_dim_sizes);
92205 ~EinsumTree ();
93206
207+ /* *
208+ * Parses the einsum tree string and builds the tree structure.
209+ *
210+ * @return ErrorParse indicating the result of the parsing operation.
211+ */
94212 ErrorParse parse_tree ();
213+
214+ /* *
215+ * Returns the root node of the EinsumTree.
216+ *
217+ * @return Pointer to the root EinsumNode.
218+ */
95219 EinsumNode *get_root () const ;
220+
221+ /* *
222+ * Executes the einsum operation defined by the tree.
223+ *
224+ * @param tensors A vector of pointers to the input tensors of the leaves.
225+ * @return ErrorExecute indicating the result of the execution operation.
226+ */
96227 ErrorExecute execute (std::vector<void *> tensors);
97228 };
98229}; // namespace mini_jit
0 commit comments