@@ -159,10 +159,16 @@ mini_jit::EinsumTree::~EinsumTree()
159159
160160void mini_jit::EinsumTree::delete_tree (EinsumNode *node)
161161{
162- if (!node)
162+ if (node == nullptr )
163+ {
163164 return ;
165+ }
166+
164167 delete_tree (node->left );
165168 delete_tree (node->right );
169+ node->left = nullptr ;
170+ node->right = nullptr ;
171+
166172 if (node->type != NodeType::Leaf && node->tensor != nullptr && node != get_root ())
167173 {
168174 delete[] node->tensor ;
@@ -224,6 +230,16 @@ std::string mini_jit::EinsumTree::EinsumNode::to_string() const
224230 return mini_jit::EinsumTree::EinsumNode::_to_string (0 , " " , " " );
225231}
226232
233+ std::string mini_jit::EinsumTree::EinsumNode::name () const
234+ {
235+ std::string output = std::format (" {}" , output_dim_ids[0 ]);
236+ for (auto iDim = output_dim_ids.begin () + 1 ; iDim != output_dim_ids.end (); iDim++)
237+ {
238+ output += std::format (" _{}" , *iDim);
239+ }
240+ return output;
241+ }
242+
227243mini_jit::TensorConfig mini_jit::EinsumTree::lower_node (const EinsumNode *node)
228244{
229245 // Node has two children -> contraction
@@ -477,6 +493,9 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
477493 return error;
478494 }
479495
496+ #ifdef SAVE_JITS_TO_FILE
497+ tensor_op.write_kernel_to_file (node->name ());
498+ #endif // SAVE_JITS_TO_FILE
480499 tensor_op.execute (node->left ->tensor , nullptr , node->tensor );
481500 }
482501 else if (node->type == NodeType::Contraction)
@@ -517,6 +536,9 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
517536 return error;
518537 }
519538
539+ #ifdef SAVE_JITS_TO_FILE
540+ tensor_op.write_kernel_to_file (node->name ());
541+ #endif // SAVE_JITS_TO_FILE
520542 tensor_op.execute (node->left ->tensor , node->right ->tensor , node->tensor );
521543 }
522544 else
0 commit comments